fix: attempt forward on flash attn2 to check hardware support (#2335)
* fix: attempt forward on flash attn2 to check hardware support * fix: warn window_size_left when using flash attn 1 * fix: prefer version check over test op and avoid window_size_left if not flash attn2 * fix: improve condtional and error message * fix: update sliding window conditional * fix: simplify changes and revert model changes * fix: avoid changing conditional * fix: typo tweak
This commit is contained in:
parent
47447ef017
commit
215ed3ad52
|
@ -172,6 +172,10 @@ def paged_attention(
|
|||
|
||||
|
||||
try:
|
||||
is_ampere_or_newer = major >= 8 and minor >= 0
|
||||
if not is_ampere_or_newer:
|
||||
raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
|
||||
|
||||
import flash_attn_2_cuda
|
||||
|
||||
V2 = True
|
||||
|
|
|
@ -484,6 +484,9 @@ def get_model(
|
|||
)
|
||||
sliding_window = config_dict.get("sliding_window", -1)
|
||||
|
||||
if max_input_tokens is not None and max_input_tokens <= sliding_window:
|
||||
sliding_window = -1
|
||||
|
||||
if (
|
||||
(sliding_window is not None and sliding_window != -1)
|
||||
and not SUPPORTS_WINDOWING
|
||||
|
|
Loading…
Reference in New Issue