diff --git a/server/text_generation_server/layers/attention/cuda.py b/server/text_generation_server/layers/attention/cuda.py
index dff742dc..2b898831 100644
--- a/server/text_generation_server/layers/attention/cuda.py
+++ b/server/text_generation_server/layers/attention/cuda.py
@@ -172,6 +172,10 @@ def paged_attention(
 try:
+    is_ampere_or_newer = major >= 8 and minor >= 0
+    if not is_ampere_or_newer:
+        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
+
     import flash_attn_2_cuda

     V2 = True
diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index 3dc24159..ae791ef8 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -484,6 +484,9 @@ def get_model(
         )
     sliding_window = config_dict.get("sliding_window", -1)

+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
+        sliding_window = -1
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
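
For context, a minimal standalone sketch of what the new guard in cuda.py does. It assumes `major` and `minor` hold the CUDA compute capability, e.g. from `torch.cuda.get_device_capability()` (an assumption; that query is not shown in this hunk). Because the guard raises `ImportError` before `import flash_attn_2_cuda` is attempted, pre-Ampere GPUs take the same fallback path as a missing flash-attention build.

```python
import torch

# Hypothetical standalone version of the guard; in the real module `major`
# and `minor` are assumed to already be set from a capability query like this.
major, minor = torch.cuda.get_device_capability()

# FlashAttention v2 kernels require compute capability 8.x (Ampere) or newer.
# Note that `minor >= 0` is always true, so the effective check is `major >= 8`.
is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
    raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
```

The second hunk follows the same spirit of degrading gracefully: when `max_input_tokens` fits entirely within the model's sliding window, windowed attention buys nothing, so `sliding_window` is reset to -1 and the subsequent `SUPPORTS_WINDOWING` check no longer rejects backends that lack windowing support.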