fix: max_past default value must be -1, not 0 (#1348)
parent 9b78a6eee3
commit 37555cf4e8
@@ -149,7 +149,7 @@ class MistralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
@@ -204,7 +204,7 @@ class MixtralAttention(torch.nn.Module):
     ):
         super().__init__()
         self.max_past = (
-            config.sliding_window if config.sliding_window is not None else 0
+            config.sliding_window if config.sliding_window is not None else -1
        )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
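The first two hunks are identical in substance: when the model config defines no sliding window, max_past now falls back to -1 (the "unlimited" sentinel) instead of 0. A minimal standalone sketch of that fallback, using an illustrative helper and a mock config object rather than the repository's classes:

from types import SimpleNamespace

def resolve_max_past(config):
    # Same expression as in the hunks above: -1 now stands for "no sliding window".
    return config.sliding_window if config.sliding_window is not None else -1

print(resolve_max_past(SimpleNamespace(sliding_window=4096)))  # 4096
print(resolve_max_past(SimpleNamespace(sliding_window=None)))  # -1 (previously 0)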
@@ -72,6 +72,9 @@ def attention(
     softmax_scale,
     window_size_left=-1,
 ):
+    if window_size_left <= 0 and window_size_left != -1:
+        raise ValueError("`window_size_left` must be > 0 or -1")
+
     if HAS_FLASH_ATTN_V2_CUDA:
         return flash_attn_2_cuda.varlen_fwd(
             q,
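The last hunk adds a guard that makes the sentinel explicit: window_size_left must be a positive window length or -1, so the old fallback of 0 would now be rejected while -1 passes. A small self-contained illustration of that check (the helper name below is ours, not the repository's):

def check_window_size_left(window_size_left):
    # Same condition as the added lines: only positive sizes or the -1 sentinel are valid.
    if window_size_left <= 0 and window_size_left != -1:
        raise ValueError("`window_size_left` must be > 0 or -1")

for value in (4096, -1, 0):
    try:
        check_window_size_left(value)
        print(value, "accepted")
    except ValueError as exc:
        print(value, "rejected:", exc)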