REvert changes in modeling.
This commit is contained in:
parent
63e72033b7
commit
01e4442ef6
|
@ -151,13 +151,14 @@ class FlashLlamaAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
# output tensor
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
paged_attention.reshape_and_cache(
|
paged_attention.reshape_and_cache(
|
||||||
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
flash_attn.attention(
|
flash_attn.attention(
|
||||||
|
|
|
@ -214,8 +214,6 @@ class MistralAttention(torch.nn.Module):
|
||||||
|
|
||||||
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
|
||||||
|
|
||||||
attn_output = torch.empty_like(query)
|
|
||||||
|
|
||||||
if prefill_cache_indices is not None:
|
if prefill_cache_indices is not None:
|
||||||
kv_to_cache = kv[prefill_cache_indices]
|
kv_to_cache = kv[prefill_cache_indices]
|
||||||
else:
|
else:
|
||||||
|
@ -224,6 +222,10 @@ class MistralAttention(torch.nn.Module):
|
||||||
paged_attention.reshape_and_cache(
|
paged_attention.reshape_and_cache(
|
||||||
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# output tensor
|
||||||
|
attn_output = torch.empty_like(query)
|
||||||
|
|
||||||
# Prefill
|
# Prefill
|
||||||
if cu_seqlen_prefill is not None:
|
if cu_seqlen_prefill is not None:
|
||||||
# flash attention
|
# flash attention
|
||||||
|
@ -235,6 +237,7 @@ class MistralAttention(torch.nn.Module):
|
||||||
cu_seqlen_prefill,
|
cu_seqlen_prefill,
|
||||||
max_s,
|
max_s,
|
||||||
self.softmax_scale,
|
self.softmax_scale,
|
||||||
|
window_size_left=self.max_past,
|
||||||
)
|
)
|
||||||
# Decode
|
# Decode
|
||||||
else:
|
else:
|
||||||
|
|
Loading…
Reference in New Issue