Fix TGI issues with ROCm (#1921)
Not all models were tested in https://github.com/huggingface/text-generation-inference/pull/1764. This commit fixes a few more issues (notably with starcoder2); the full CI will come shortly, once we split `build.yml` in two.
parent b5f1c9de06
commit 5dad0c0b29
@@ -396,36 +396,38 @@ jobs:
           label: ${{ needs.start-runner.outputs.label }}
           ec2-instance-id: ${{ needs.start-runner.outputs.ec2-instance-id }}
 
-  integration-tests-rocm:
-    concurrency:
-      group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
-      cancel-in-progress: true
-    needs:
-      - start-runner
-      - build-and-push-image
-      - integration-tests
-      - build-and-push-image-rocm
-      - stop-runner
-    runs-on: [self-hosted, docker-gpu, amd-gpu, multi-gpu, mi300]
-    container:
-      image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
-      options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
-    env:
-      DOCKER_VOLUME: /cache
-    steps:
-      - name: ROCM-SMI
-        run: |
-          rocm-smi
-      - name: ROCM-INFO
-        run: |
-          rocminfo | grep "Agent" -A 14
-      - name: Show ROCR environment
-        run: |
-          echo "ROCR: $ROCR_VISIBLE_DEVICES"
-      - name: Install
-        run: |
-          make install-integration-tests
-      - name: Run tests
-        run: |
-          export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
-          pytest -s -vv integration-tests
+  # TODO: Move this to `build_amd.yml` (and `build_nvidia.yml`)
+
+  # integration-tests-rocm:
+  #   concurrency:
+  #     group: ${{ github.workflow }}-${{ github.job }}-${{ github.head_ref || github.run_id }}
+  #     cancel-in-progress: true
+  #   needs:
+  #     - start-runner
+  #     - build-and-push-image
+  #     - integration-tests
+  #     - build-and-push-image-rocm
+  #     - stop-runner
+  #   runs-on: [self-hosted, amd-gpu, multi-gpu, mi300]
+  #   container:
+  #     image: registry.internal.huggingface.tech/api-inference/community/text-generation-inference:sha-${{ needs.build-and-push-image-rocm.outputs.short_sha }}-rocm
+  #     options: --device /dev/kfd --device /dev/dri --env ROCR_VISIBLE_DEVICES --shm-size "16gb" --ipc host -v /mnt/cache/.cache/huggingface:/cache
+  #   env:
+  #     DOCKER_VOLUME: /cache
+  #   steps:
+  #     - name: ROCM-SMI
+  #       run: |
+  #         rocm-smi
+  #     - name: ROCM-INFO
+  #       run: |
+  #         rocminfo | grep "Agent" -A 14
+  #     - name: Show ROCR environment
+  #       run: |
+  #         echo "ROCR: $ROCR_VISIBLE_DEVICES"
+  #     - name: Install
+  #       run: |
+  #         make install-integration-tests
+  #     - name: Run tests
+  #       run: |
+  #         export HUGGING_FACE_HUB_TOKEN=${{ secrets.HUGGING_FACE_HUB_TOKEN }}
+  #         pytest -s -vv integration-tests
@@ -79,12 +79,15 @@ try:
     from text_generation_server.models.flash_phi import FlashPhi
     from text_generation_server.models.flash_starcoder2 import FlashStarcoder2
     from text_generation_server.models.flash_dbrx import FlashDbrx
-    from text_generation_server.utils.flash_attn import HAS_FLASH_ATTN_V2_CUDA
+    from text_generation_server.utils.flash_attn import (
+        HAS_FLASH_ATTN_V2_CUDA,
+        HAS_FLASH_ATTN_V2_ROCM,
+    )
 except ImportError as e:
     logger.warning(f"Could not import Flash Attention enabled models: {e}")
     FLASH_ATTENTION = False
     HAS_FLASH_ATTN_V2_CUDA = False
     HAS_FLASH_ATTN_V2_ROCM = False
 
 if FLASH_ATTENTION:
     __all__.append(FlashGPT2)
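For orientation, `HAS_FLASH_ATTN_V2_CUDA` and `HAS_FLASH_ATTN_V2_ROCM` are plain booleans resolved once at import time in `text_generation_server.utils.flash_attn`, so `get_model` below only branches on flags and never imports any kernel code itself. A minimal sketch of that pattern, assuming a hypothetical kernel module name (this is not the actual `utils/flash_attn.py`):

```python
# Sketch only: illustrates the import-time capability-flag pattern, not TGI's real module.
import torch

HAS_FLASH_ATTN_V2_CUDA = False
HAS_FLASH_ATTN_V2_ROCM = False

try:
    import flash_attn_2_cuda  # hypothetical extension module name
    _has_kernels = True
except ImportError:
    _has_kernels = False

if _has_kernels:
    # torch.version.cuda is set on CUDA builds of PyTorch, torch.version.hip on ROCm builds.
    HAS_FLASH_ATTN_V2_CUDA = torch.version.cuda is not None
    HAS_FLASH_ATTN_V2_ROCM = torch.version.hip is not None
```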
@@ -539,8 +542,10 @@ def get_model(
     if model_type == "mistral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMistral(
                 model_id,
                 revision,
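The same widened predicate is repeated below for mixtral, starcoder2 and qwen2. Purely as an illustration (not part of this commit), the duplicated condition could be factored into one helper:

```python
from typing import Optional

def flash_v2_supported(
    sliding_window: Optional[int],
    flash_attention: bool,
    has_v2_cuda: bool,
    has_v2_rocm: bool,
) -> bool:
    """True when the flash path can serve the model.

    A model without a sliding window (or with it disabled via -1) only needs base
    flash-attention support; otherwise flash-attention v2 is required, and with
    this commit the ROCm build of v2 is accepted alongside the CUDA one.
    """
    return (
        ((sliding_window is None or sliding_window == -1) and flash_attention)
        or has_v2_cuda
        or has_v2_rocm
    )
```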
@@ -564,8 +569,10 @@ def get_model(
     if model_type == "mixtral":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashMixtral(
                 model_id,
                 revision,
@@ -589,8 +596,10 @@ def get_model(
     if model_type == "starcoder2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashStarcoder2(
                 model_id,
                 revision,
@@ -615,8 +624,10 @@ def get_model(
     if model_type == "qwen2":
         sliding_window = config_dict.get("sliding_window", -1)
         if (
-            (sliding_window is None or sliding_window == -1) and FLASH_ATTENTION
-        ) or HAS_FLASH_ATTN_V2_CUDA:
+            ((sliding_window is None or sliding_window == -1) and FLASH_ATTENTION)
+            or HAS_FLASH_ATTN_V2_CUDA
+            or HAS_FLASH_ATTN_V2_ROCM
+        ):
             return FlashQwen2(
                 model_id,
                 revision,
@@ -230,11 +230,15 @@ class LlamaMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
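The new `self.quantize` flag gates the ROCm fast path (a fused kernel for the SiLU MLP used at batch size 1) so it is skipped for quantized layers, whose packed weights the fused call presumably cannot consume; the same guard is added to `MistralMLP` below. For reference, a self-contained sketch of the SwiGLU computation that the generic fallback performs (helper name and shapes are illustrative, not TGI's API):

```python
import torch
import torch.nn.functional as F

def swiglu_reference(gate_up: torch.Tensor, intermediate_size: int) -> torch.Tensor:
    """gate_up is the concatenated output of the fused gate/up projection."""
    gate = gate_up[..., :intermediate_size]
    up = gate_up[..., intermediate_size:]
    return F.silu(gate) * up

# Decode-time shape: one token, already projected to 2 * intermediate_size.
x = torch.randn(1, 2 * 8)
print(swiglu_reference(x, intermediate_size=8).shape)  # torch.Size([1, 8])
```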
@@ -290,11 +290,15 @@ class MistralMLP(nn.Module):
             config.intermediate_size // weights.process_group.size()
         )
 
+        # TODO: This is a hotfix to be removed & properly refactored.
+        self.quantize = config.quantize
+
     def forward(self, hidden_states):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
             and hidden_states.shape[0] == 1
+            and not self.quantize
         ):
             out = torch.empty(
                 hidden_states.shape[0],
@@ -890,6 +890,9 @@ class FlashCausalLM(Model):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -899,7 +902,7 @@ class FlashCausalLM(Model):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,
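These two hunks belong together: the warmup pass previously called `forward(..., input_lengths=None)`, which is fine for models that ignore the argument but not for those (starcoder2) that apply tensor ops to it, so a dummy tensor of ones is created and passed instead; the same change is mirrored in `BaseFlashMistral` below. A tiny, generic illustration of the failure mode (not the actual starcoder2 forward):

```python
import torch

def clamp_to_window(input_lengths, sliding_window: int):
    # Representative of forward-pass code that assumes input_lengths is a tensor.
    return torch.clamp(input_lengths, max=sliding_window)

seqlen = 4
print(clamp_to_window(torch.ones(seqlen, dtype=torch.int32), sliding_window=2))
# tensor([1, 1, 1, 1], dtype=torch.int32)

try:
    clamp_to_window(None, sliding_window=2)
except TypeError as err:
    print(f"None is rejected: {err}")
```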
@@ -397,6 +397,9 @@ class BaseFlashMistral(FlashCausalLM):
         slots = torch.arange(seqlen, dtype=torch.int64, device=self.device)
         kv_cache = get_cache_manager().kv_cache
 
+        # Dummy value, some models (starcoder2) don't accept `None`.
+        input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
+
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=input_ids,
@@ -406,7 +409,7 @@ class BaseFlashMistral(FlashCausalLM):
             ),
             kv_cache=get_cache_manager().kv_cache,
             block_tables=None,
-            input_lengths=None,
+            input_lengths=input_lengths,
             slots=slots,
             max_s=seqlen,
             lm_head_indices=None,