fix: attempt forward on flash attn2 to check hardware support (#2335)
* fix: attempt forward on flash attn2 to check hardware support
* fix: warn window_size_left when using flash attn 1
* fix: prefer version check over test op and avoid window_size_left if not flash attn2
* fix: improve conditional and error message
* fix: update sliding window conditional
* fix: simplify changes and revert model changes
* fix: avoid changing conditional
* fix: typo tweak
commit 215ed3ad52
parent 47447ef017
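One of the squashed changes above warns when window_size_left is requested but only flash attention v1 is available. The diff below does not show that path, so the following is only an illustrative sketch of the described behavior; the function name and warning message are assumptions, not the repository's code.

import warnings

# Illustrative sketch only: "warn window_size_left when using flash attn 1".
# V2 mirrors the flag set in the first hunk below; -1 means "no sliding window".
def clamp_window_size_left(window_size_left: int, v2: bool) -> int:
    if not v2 and window_size_left != -1:
        warnings.warn(
            "window_size_left is only supported with FlashAttention v2; ignoring it."
        )
        return -1
    return window_size_left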
@@ -172,6 +172,10 @@ def paged_attention(
 try:
+    is_ampere_or_newer = major >= 8 and minor >= 0
+    if not is_ampere_or_newer:
+        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")
+
     import flash_attn_2_cuda

     V2 = True
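For context, a minimal sketch of how such a capability gate fits together. It assumes `major` and `minor` come from torch.cuda.get_device_capability(), which lies outside this hunk, and the fallback handling is illustrative rather than the module's actual error path.

import torch

# Sketch, assuming a CUDA device is present: FlashAttention v2 requires
# compute capability 8.0 (Ampere) or newer, so check before importing the kernel.
major, minor = torch.cuda.get_device_capability()

V2 = False
try:
    is_ampere_or_newer = major >= 8 and minor >= 0
    if not is_ampere_or_newer:
        raise ImportError("FlashAttention only supports Ampere GPUs or newer.")

    import flash_attn_2_cuda  # noqa: F401

    V2 = True
except ImportError as e:
    # The real module falls back to flash attention v1 or re-raises; here we only record it.
    print(f"FlashAttention v2 unavailable: {e}")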
@@ -484,6 +484,9 @@ def get_model(
     )
     sliding_window = config_dict.get("sliding_window", -1)

+    if max_input_tokens is not None and max_input_tokens <= sliding_window:
+        sliding_window = -1
+
     if (
         (sliding_window is not None and sliding_window != -1)
         and not SUPPORTS_WINDOWING
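The intent of the three added lines, written out as a standalone sketch; the helper name and the raised error are assumptions, while max_input_tokens, sliding_window, and SUPPORTS_WINDOWING follow the names used in the hunk.

# Sketch of the sliding-window decision above: if every prompt fits inside the
# window, windowing is unnecessary, so it is disabled (-1) and backends without
# windowing support no longer hit the error path.
def resolve_sliding_window(sliding_window, max_input_tokens, supports_windowing):
    if (
        sliding_window is not None
        and max_input_tokens is not None
        and max_input_tokens <= sliding_window
    ):
        sliding_window = -1  # inputs can never overflow the window

    if (sliding_window is not None and sliding_window != -1) and not supports_windowing:
        raise ValueError("Sliding window attention is not supported by this backend.")
    return sliding_window


# Example: a 4096-token window with prompts capped at 2048 tokens degrades to -1,
# so a backend without windowing support remains usable.
assert resolve_sliding_window(4096, 2048, supports_windowing=False) == -1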