Fix TGI issues with ROCm (#1921)

Not all models were tested in
https://github.com/huggingface/text-generation-inference/pull/1764.

Fixing some more issues (notably starcoder2) here; the full CI will come
shortly once we split `build.yml` in two.
commit 5dad0c0b29 (parent b5f1c9de06)
Author: fxmarty
Date: 2024-05-17 19:50:52 +02:00 (committed via GitHub)
6 changed files with 72 additions and 45 deletions

View File

@@ -396,36 +396,38 @@ jobs:
           label: ${{ needs.start-runner.outputs.label }}
           ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
 
-  integration-tests-rocm:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image
-      - integration-tests
-      - build-and-push-image-rocm
-      - stop-runner
-    runs-on: [self-hosted, docker-gpu, amd-gpu, multi-gpu, mi300]
-    container:
-      image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
-      options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - name: ROCM-SMI
-        run: |
-          rocm-smi
-      - name: ROCM-INFO
-        run: |
-          rocminfo | grep "Agent" -A 14
-      - name: Show ROCR environment
-        run: |
-          echo "ROCR: $ROCR_VISIBLE_DEVICES"
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
+  # TODO: Move this to `build_amd.yml` (and `build_nvidia.yml`)
+
+  # integration-tests-rocm:
+  #   concurrency:
+  #     group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+  #     cancel-in-progress: true
+  #   needs:
+  #     - start-runner
+  #     - build-and-push-image
+  #     - integration-tests
+  #     - build-and-push-image-rocm
+  #     - stop-runner
+  #   runs-on: [self-hosted, amd-gpu, multi-gpu, mi300]
+  #   container:
+  #     image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
+  #     options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
+  #   env:
+  #     DOCKER_VOLUME: /cache
+  #   steps:
+  #     - name: ROCM-SMI
+  #       run: |
+  #         rocm-smi
+  #     - name: ROCM-INFO
+  #       run: |
+  #         rocminfo | grep "Agent" -A 14
+  #     - name: Show ROCR environment
+  #       run: |
+  #         echo "ROCR: $ROCR_VISIBLE_DEVICES"
+  #     - name: Install
+  #       run: |
+  #         make install-integration-tests
+  #     - name: Run tests
+  #       run: |
+  #         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+  #         pytest -s -vv integration-tests

View File

@@ -79,12 +79,15 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
-
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
+    HAS_FLASH_ATTN_V2_ROCM = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
@@ -539,8 +542,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
@@ -564,8 +569,10 @@ def get_model(
     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +596,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +624,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
        if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
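
Note: with `HAS_FLASH_ATTN_V2_ROCM` imported, the sliding-window gate in `get_model` treats a ROCm flash-attention v2 build the same as the CUDA one for mistral, mixtral, starcoder2 and qwen2. A minimal sketch of that shared condition follows (the standalone helper is purely illustrative; in the repository the check stays inlined per model type):

```python
# Illustrative helper mirroring the gate used in get_model(); the real check is
# written inline for each model type (mistral, mixtral, starcoder2, qwen2).
def can_use_flash_variant(
    config_dict: dict,
    flash_attention: bool,
    has_flash_attn_v2_cuda: bool,
    has_flash_attn_v2_rocm: bool,
) -> bool:
    sliding_window = config_dict.get("sliding_window", -1)
    # No sliding window: generic FLASH_ATTENTION support is enough.
    # Sliding window set: flash-attn v2 is required, now accepted on CUDA or ROCm.
    return (
        ((sliding_window is None or sliding_window == -1) and flash_attention)
        or has_flash_attn_v2_cuda
        or has_flash_attn_v2_rocm
    )


# Example: a Mistral-style config with a sliding window now dispatches to the
# flash implementation on a ROCm build that exposes flash-attn v2.
assert can_use_flash_variant({"sliding_window": 4096}, True, False, True)
assert not can_use_flash_variant({"sliding_window": 4096}, False, False, False)
```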

View File

@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],

View File

@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
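
Note: the LlamaMLP and MistralMLP hunks are the same hotfix: the ROCm-only fused SiLU path is now skipped whenever the model is quantized, presumably because that path reads the raw `gate_up_proj` weights, which quantized linear layers do not expose in the same way. A rough sketch of the resulting guard (illustrative helper; in the repository the check lives inline in `forward`):

```python
import torch

SYSTEM = "rocm"  # stand-in for TGI's SYSTEM constant on a ROCm build


def use_fused_silu_path(hidden_states: torch.Tensor, hidden_act: str, quantize) -> bool:
    # Mirrors the guard in LlamaMLP/MistralMLP.forward: only take the ROCm fused
    # SiLU kernel for a single decode token on an unquantized model.
    return (
        SYSTEM == "rocm"
        and hidden_act == "silu"
        and hidden_states.shape[0] == 1
        and not quantize
    )


# Example: with quantize="gptq" the model now falls back to the regular
# act_fn(gate) * up path even on ROCm with batch size 1.
x = torch.randn(1, 4096)
assert use_fused_silu_path(x, "silu", None)
assert not use_fused_silu_path(x, "silu", "gptq")
```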

View File

@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,

View File

@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
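
Note: both warmup hunks replace `input_lengths=None` with a tensor of ones, since some model forwards (starcoder2) use `input_lengths` even during the dummy prefill. A small, self-contained sketch of the construction, with the surrounding warmup machinery omitted and `seqlen` chosen arbitrarily here:

```python
import torch

# Sketch of the warmup fix: build a dummy per-slot length tensor instead of
# passing None, because some model forwards (e.g. starcoder2) read it even
# during the dummy prefill pass.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
seqlen = 256  # illustrative; the real warmup derives this from the batch

slots = torch.arange(seqlen, dtype=torch.int64, device=device)
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=device)

# These are then passed as `input_lengths=input_lengths, slots=slots` to the
# dummy self.model.forward(...) call in place of the previous input_lengths=None.
```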