From 215ed3ad52651f76ca4326713ba9e4e5107323e5 Mon Sep 17 00:00:00 2001
From: drbh
Date: Mon, 5 Aug 2024 09:11:40 -0400
Subject: [PATCH] fix: attempt forward on flash attn2 to check hardware
 support (#2335)

* fix: attempt forward on flash attn2 to check hardware support

* fix: warn window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve condtional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
---
 server/text_generation_server/layers/attention/cuda.py | 4 ++++
 server/text_generation_server/models/__init__.py       | 3 +++
 2 files changed, 7 insertions(+)

diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index dff742dc..2b898831 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -172,6 +172,10 @@ def paged_attention(
 
 
 try:
+    is_ampere_or_newer = major >= 8 and minor >= 0
+    if not is_ampere_or_newer:
+        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
+
     import flash_attn_2_cuda
 
     V2 = True
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3dc24159..ae791ef8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -484,6 +484,9 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)
 
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
+        sliding_window = -1
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
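
Note (not part of the patch): a minimal sketch of how the gate added in
cuda.py behaves, assuming `major`/`minor` come from
torch.cuda.get_device_capability() as elsewhere in that module and that a
CUDA-capable torch install is available; the fallback flag here is
illustrative, not the module's real v1 path.

    import torch

    # Compute capability, e.g. (8, 0) for A100, (7, 5) for T4.
    major, minor = torch.cuda.get_device_capability()

    try:
        # flash-attn v2 kernels require compute capability >= 8.0 (Ampere),
        # so refuse the v2 import on older GPUs instead of failing at runtime.
        if major < 8:
            raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
        import flash_attn_2_cuda  # noqa: F401

        V2 = True
    except ImportError:
        # The real module falls back to its flash-attn v1 code path here.
        V2 = False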
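
Note (not part of the patch): the sliding-window override in
models/__init__.py treats the window as disabled whenever every prompt
already fits inside it, so backends without SUPPORTS_WINDOWING no longer
need to reject such models. A hedged sketch of that rule, using an
illustrative helper name:

    def effective_sliding_window(sliding_window, max_input_tokens):
        # If the longest allowed prompt fits inside the window, windowed
        # attention never drops context, so behave as if no window is set.
        if (
            sliding_window is not None
            and sliding_window != -1
            and max_input_tokens is not None
            and max_input_tokens <= sliding_window
        ):
            return -1
        return sliding_window

    assert effective_sliding_window(4096, 2048) == -1    # prompt fits: window off
    assert effective_sliding_window(4096, 8192) == 4096  # prompt exceeds window: keep it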