diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml
index c6790254..432d20df 100644
--- a/.github/workflows/build.yaml
+++ b/.github/workflows/build.yaml
@@ -396,36 +396,38 @@ jobs:
           label: ${{ needs.start-runner.outputs.label }}
           ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
 
-  integration-tests-rocm:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image
-      - integration-tests
-      - build-and-push-image-rocm
-      - stop-runner
-    runs-on: [self-hosted, docker-gpu, amd-gpu, multi-gpu, mi300]
-    container:
-      image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
-      options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - name: ROCM-SMI
-        run: |
-          rocm-smi
-      - name: ROCM-INFO
-        run: |
-          rocminfo | grep "Agent" -A 14
-      - name: Show ROCR environment
-        run: |
-          echo "ROCR: $ROCR_VISIBLE_DEVICES"
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
+  # TODO: Move this to `build_amd.yml` (and `build_nvidia.yml`)
+
+  # integration-tests-rocm:
+  #   concurrency:
+  #     group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+  #     cancel-in-progress: true
+  #   needs:
+  #     - start-runner
+  #     - build-and-push-image
+  #     - integration-tests
+  #     - build-and-push-image-rocm
+  #     - stop-runner
+  #   runs-on: [self-hosted, amd-gpu, multi-gpu, mi300]
+  #   container:
+  #     image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
+  #     options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
+  #   env:
+  #     DOCKER_VOLUME: /cache
+  #   steps:
+  #     - name: ROCM-SMI
+  #       run: |
+  #         rocm-smi
+  #     - name: ROCM-INFO
+  #       run: |
+  #         rocminfo | grep "Agent" -A 14
+  #     - name: Show ROCR environment
+  #       run: |
+  #         echo "ROCR: $ROCR_VISIBLE_DEVICES"
+  #     - name: Install
+  #       run: |
+  #         make install-integration-tests
+  #     - name: Run tests
+  #       run: |
+  #         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+  #         pytest -s -vv integration-tests
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 8878ad15..9e5676f5 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -79,12 +79,15 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
-
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
@@ -539,8 +542,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -564,8 +569,10 @@ def get_model(
     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +596,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +624,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
index 47758d30..6e23aa2b 100644
--- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 21edc79e..ef3777da 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
    def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index 333efe33..45ddd856 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
diff --git a/server/text_generation_server/models/flash_mistral.py b/server/text_generation_server/models/flash_mistral.py
index 30ae95c9..e6125e29 100644
--- a/server/text_generation_server/models/flash_mistral.py
+++ b/server/text_generation_server/models/flash_mistral.py
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,