Revert changes in modeling.
parent 63e72033b7
commit 01e4442ef6
@@ -151,13 +151,14 @@ class FlashLlamaAttention(torch.nn.Module):
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        # output tensor
        attn_output = torch.empty_like(query)

        paged_attention.reshape_and_cache(
            kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
            flash_attn.attention(
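The hunk above shows the output-tensor allocation both before and after the paged-cache write; the revert appears to move attn_output = torch.empty_like(query) back below paged_attention.reshape_and_cache. For readers unfamiliar with that call, here is a minimal pure-PyTorch sketch of what it does functionally: each token's key and value are written into the KV cache at the slot assigned to it. The flat cache layout and shapes are assumptions for illustration only; the real kernel uses a block-paged layout.

import torch

def reshape_and_cache_sketch(key, value, key_cache, value_cache, slots):
    # key/value:             [num_tokens, num_kv_heads, head_dim]
    # key_cache/value_cache: [num_slots, num_kv_heads, head_dim] (assumed flat layout)
    # slots:                 [num_tokens] slot index per token
    key_cache[slots] = key
    value_cache[slots] = value

num_tokens, num_kv_heads, head_dim, num_slots = 4, 2, 8, 16
kv = torch.randn(num_tokens, 2, num_kv_heads, head_dim)    # packed key/value, as in the diff
key_cache = torch.zeros(num_slots, num_kv_heads, head_dim)
value_cache = torch.zeros(num_slots, num_kv_heads, head_dim)
slots = torch.tensor([3, 7, 8, 9])

reshape_and_cache_sketch(kv[:, 0], kv[:, 1], key_cache, value_cache, slots)
assert torch.equal(key_cache[slots], kv[:, 0])
assert torch.equal(value_cache[slots], kv[:, 1])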
@@ -214,8 +214,6 @@ class MistralAttention(torch.nn.Module):
        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

        attn_output = torch.empty_like(query)

        if prefill_cache_indices is not None:
            kv_to_cache = kv[prefill_cache_indices]
        else:
@@ -224,6 +222,10 @@ class MistralAttention(torch.nn.Module):
        paged_attention.reshape_and_cache(
            kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
        )

        # output tensor
        attn_output = torch.empty_like(query)

        # Prefill
        if cu_seqlen_prefill is not None:
            # flash attention
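In the Mistral hunks, kv_to_cache = kv[prefill_cache_indices] selects which prefill tokens are written to the cache; with sliding-window attention, a prompt longer than the window only needs its most recent tokens cached. A purely illustrative example of that selection (the index construction below is an assumption, not TGI's actual logic):

import torch

# Hypothetical illustration: keep only the last `window` prompt tokens in the KV cache.
seq_len, window = 10, 4
prefill_cache_indices = torch.arange(max(0, seq_len - window), seq_len)

kv = torch.randn(seq_len, 2, 8, 64)      # [tokens, (key|value), num_kv_heads, head_dim]
kv_to_cache = kv[prefill_cache_indices]  # same indexing as in the hunk above
print(kv_to_cache.shape)                 # torch.Size([4, 2, 8, 64])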
@@ -235,6 +237,7 @@ class MistralAttention(torch.nn.Module):
                cu_seqlen_prefill,
                max_s,
                self.softmax_scale,
                window_size_left=self.max_past,
            )
        # Decode
        else:
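The restored window_size_left=self.max_past argument passes Mistral's sliding-window size to the flash-attention call in the prefill branch: a query position attends only to keys within a bounded left window. A mask-based sketch of that constraint (illustration only; the kernel never materializes such a mask, and its exact boundary convention may differ by one):

import torch

def sliding_window_causal_mask(seq_len: int, window: int) -> torch.Tensor:
    # True where query position i may attend to key position j:
    # causal (j <= i) and within the left window (i - j < window).
    idx = torch.arange(seq_len)
    rel = idx[:, None] - idx[None, :]   # i - j
    return (rel >= 0) & (rel < window)

print(sliding_window_causal_mask(seq_len=6, window=3).int())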