diff --git a/server/text_generation_server/layers/attention/rocm.py b/server/text_generation_server/layers/attention/rocm.py
index 2d3601c8..535810aa 100644
--- a/server/text_generation_server/layers/attention/rocm.py
+++ b/server/text_generation_server/layers/attention/rocm.py
@@ -126,40 +126,34 @@ if ENGINE != "triton":
         import flash_attn_2_cuda
 
         logger.info("ROCm: using Flash Attention 2 Composable Kernel implementation.")
-    except ImportError:
-        try:
-            import flash_attn_cuda
+    except ImportError as e:
+        if major >= 8:
+            architecture_suffix = f"-{SYSTEM}"
+            raise ImportError(
+                "Flash Attention V2 is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
+            )
+        elif is_sm75:
+            raise ImportError(
+                "Flash Attention is not installed.\n"
+                "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
+                "or install flash attention with `cd server && make install install-flash-attention`"
+            ) from e
+        else:
-            ENGINE = "v1"
-            logger.info("ROCm: using Flash Attention 1")
-        except ImportError as e:
-            if major >= 8:
-                architecture_suffix = f"-{SYSTEM}"
-                raise ImportError(
-                    "Flash Attention V2 is not installed.\n"
-                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                    f"or install flash attention v2 with `cd server && make install install-flash-attention-v2{architecture_suffix}`"
-                )
-            elif is_sm75:
-                raise ImportError(
-                    "Flash Attention is not installed.\n"
-                    "Use the official Docker image (ghcr.io/huggingface/text-generation-inference:latest) "
-                    "or install flash attention with `cd server && make install install-flash-attention`"
-                ) from e
-            else:
-
-                for idx in range(torch.cuda.device_count()):
-                    name = torch.cuda.get_device_name(idx)
-                    if "MI210" not in name and "MI250" not in name:
-                        raise ImportError(
-                            f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
-                        )
-                raise ImportError(
-                    f"AMD GPU with ROCm capability {major} {minor} is not supported"
-                ) from e
+            for idx in range(torch.cuda.device_count()):
+                name = torch.cuda.get_device_name(idx)
+                if "MI210" not in name and "MI250" not in name:
+                    raise ImportError(
+                        f"AMD GPU {torch.cuda.get_device_name(idx)} does not support flash-attention"
+                    )
+            raise ImportError(
+                f"AMD GPU with ROCm capability {major} {minor} is not supported"
+            ) from e
 
 
 
-SUPPORTS_WINDOWING = ENGINE != "v1"
+SUPPORTS_WINDOWING = False
 if ENGINE == "ck":
 
     def attention(
@@ -186,17 +180,12 @@ if ENGINE == "ck":
             out,
             cu_seqlens,
             cu_seqlens,
-            None,
-            None,
-            None,
             max_s,
             max_s,
             0.0,
             softmax_scale,
             False,
             causal,
-            window_size_left,
-            0,
             False,
             None,
         )
@@ -234,62 +223,4 @@ elif ENGINE == "triton":
         return output
 
 else:
-
-    def attention(
-        q,
-        k,
-        v,
-        out,
-        cu_seqlens,
-        max_s,
-        softmax_scale,
-        window_size_left=-1,
-    ):
-        if window_size_left != -1:
-            raise NotImplementedError(
-                "window_size_left is only available with flash attn v2"
-            )
-
-        # Flash attention v1 requires q, k and v to have the same number of heads
-        if k.shape[1] != q.shape[1]:
-            # MQA expand
-            if k.shape[1] == 1:
-                k = k.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = k.shape
-                k = (
-                    k.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // k.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-        if v.shape[1] != q.shape[1]:
-            # MQA expand
-            if v.shape[1] == 1:
-                v = v.expand(-1, q.shape[1], -1)
-            # Grouped attention reshape
-            else:
-                original_shape = v.shape
-                v = (
-                    v.unsqueeze(2)
-                    .expand(-1, -1, q.shape[1] // v.shape[1], -1)
-                    .reshape(original_shape[0], -1, original_shape[2])
-                )
-
-        return flash_attn_cuda.fwd(
-            q,
-            k,
-            v,
-            out,
-            cu_seqlens,
-            cu_seqlens,
-            max_s,
-            max_s,
-            0.0,
-            softmax_scale,
-            False,
-            True,
-            False,
-            0,
-            None,
-        )
+    raise RuntimeError(f"Unknown attention engine {ENGINE}")