fix: attempt forward on flash attn2 to check hardware support (#2335)

* fix: attempt forward on flash attn2 to check hardware support

* fix: warn about window_size_left when using flash attn 1

* fix: prefer version check over test op and avoid window_size_left if not flash attn2

* fix: improve conditional and error message

* fix: update sliding window conditional

* fix: simplify changes and revert model changes

* fix: avoid changing conditional

* fix: typo tweak
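
The window_size_left handling mentioned in the bullets above was simplified out of the final diff, but the intent can be sketched roughly as follows. This is an illustrative, hedged sketch only; resolve_window_size_left and has_flash_attn_v2 are not identifiers from the repository. The idea: flash-attn v1 has no sliding-window support, so a non-default window_size_left should be dropped with a warning rather than silently applied.

import warnings

# Illustrative sketch only -- names are not the repository's actual helpers.
def resolve_window_size_left(window_size_left: int, has_flash_attn_v2: bool) -> int:
    """Fall back to full attention, with a warning, when only flash-attn v1 is available."""
    if window_size_left != -1 and not has_flash_attn_v2:
        warnings.warn(
            "window_size_left is only supported by flash-attn v2; "
            "falling back to full attention (window_size_left=-1)."
        )
        return -1
    return window_size_left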
drbh authored 2024-08-05 09:11:40 -04:00, committed by GitHub
parent 47447ef017
commit 215ed3ad52
2 changed files with 7 additions and 0 deletions


@@ -172,6 +172,10 @@ def paged_attention(
 try:
+    is_ampere_or_newer = major >= 8 and minor >= 0
+    if not is_ampere_or_newer:
+        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
+
     import flash_attn_2_cuda
     V2 = True
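
For context, the major/minor values compared in this hunk come from the CUDA compute capability of the active device. Below is a minimal, self-contained sketch of the gate, assuming the capability is read via torch (torch.cuda.get_device_capability is a standard call; the surrounding arrangement is only illustrative). Since minor is never negative, the condition effectively reduces to major >= 8, i.e. Ampere (compute capability 8.x) or newer.

import torch

# Sketch of the capability gate above. Examples of (major, minor):
# T4 -> (7, 5), A100 -> (8, 0), L4 -> (8, 9), H100 -> (9, 0).
major, minor = torch.cuda.get_device_capability()

is_ampere_or_newer = major >= 8 and minor >= 0
if not is_ampere_or_newer:
    raise ImportError("FlashAttention only supports Ampere GPUs or newer.")

import flash_attn_2_cuda  # only reached on supported hardware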


@@ -484,6 +484,9 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)
+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
+        sliding_window = -1
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
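
The added guard disables the sliding window whenever every request is guaranteed to fit inside it, so backends without windowing support no longer fail needlessly. A small sketch of that decision in isolation (effective_sliding_window is an illustrative name, not a helper from the repository):

from typing import Optional

def effective_sliding_window(sliding_window: int, max_input_tokens: Optional[int]) -> int:
    # If no prompt can exceed the window, windowing is a no-op, so disable it (-1)
    # and the SUPPORTS_WINDOWING check above never triggers.
    if max_input_tokens is not None and max_input_tokens <= sliding_window:
        return -1
    return sliding_window

# Inputs capped at 2048 tokens can never reach a 4096-token window -> disabled.
assert effective_sliding_window(4096, 2048) == -1
# Inputs up to 8192 tokens can exceed the window -> keep it.
assert effective_sliding_window(4096, 8192) == 4096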