Revert changes in modeling.

Nicolas Patry 2024-05-24 14:18:00 +00:00
parent 63e72033b7
commit 01e4442ef6
2 changed files with 9 additions and 5 deletions


@@ -151,13 +151,14 @@ class FlashLlamaAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        # output tensor
-        attn_output = torch.empty_like(query)
-
         paged_attention.reshape_and_cache(
             kv[:, 0], kv[:, 1], kv_cache[0], kv_cache[1], slots
         )

+        # output tensor
+        attn_output = torch.empty_like(query)
+
+        # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
             flash_attn.attention(
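For reference, the effect of this hunk is ordering: FlashLlamaAttention.forward now writes the fresh key/value pairs into the paged KV cache before allocating the attention output buffer, and the # Prefill comment comes back. Below is a minimal, hedged sketch of that ordering in plain PyTorch; cache_kv and attention_forward are hypothetical stand-ins for TGI's paged_attention.reshape_and_cache and flash_attn.attention, not the project's actual API.

# Illustrative sketch only: plain-PyTorch stand-ins that mirror the post-revert ordering.
import torch

def cache_kv(kv, k_cache, v_cache, slots):
    # Stand-in for paged_attention.reshape_and_cache: scatter K and V into their cache slots.
    k_cache[slots] = kv[:, 0]
    v_cache[slots] = kv[:, 1]

def attention_forward(query, kv, k_cache, v_cache, slots, cu_seqlen_prefill):
    # Write the new key/value pairs into the paged cache first ...
    cache_kv(kv, k_cache, v_cache, slots)

    # ... and only then allocate the output tensor (the ordering this commit restores).
    attn_output = torch.empty_like(query)

    # Prefill
    if cu_seqlen_prefill is not None:
        # flash attention (stand-in: dense scaled dot-product attention)
        out = torch.nn.functional.scaled_dot_product_attention(
            query.transpose(0, 1),     # [heads, tokens, head_dim]
            kv[:, 0].transpose(0, 1),  # keys
            kv[:, 1].transpose(0, 1),  # values
            is_causal=True,
        )
        attn_output[:] = out.transpose(0, 1)
    # Decode branch (paged attention over the cache) is omitted here.
    return attn_output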


@@ -214,8 +214,6 @@ class MistralAttention(torch.nn.Module):
         self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)

-        attn_output = torch.empty_like(query)
-
         if prefill_cache_indices is not None:
             kv_to_cache = kv[prefill_cache_indices]
         else:
@@ -224,6 +222,10 @@ class MistralAttention(torch.nn.Module):
         paged_attention.reshape_and_cache(
             kv_to_cache[:, 0], kv_to_cache[:, 1], kv_cache[0], kv_cache[1], slots
         )
+
+        # output tensor
+        attn_output = torch.empty_like(query)
+
         # Prefill
         if cu_seqlen_prefill is not None:
             # flash attention
@@ -235,6 +237,7 @@ class MistralAttention(torch.nn.Module):
                 cu_seqlen_prefill,
                 max_s,
                 self.softmax_scale,
+                window_size_left=self.max_past,
             )
         # Decode
         else:
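The last hunk restores window_size_left=self.max_past on the prefill flash attention call, which caps how far back each query token can attend (Mistral's sliding window). As a rough, hedged illustration of what such a left window means in plain PyTorch terms (not TGI's kernel API; sliding_window_mask is a made-up helper):

# Illustration only: a causal mask limited to a left window of window_size_left tokens.
import torch

def sliding_window_mask(seq_len: int, window_size_left: int) -> torch.Tensor:
    i = torch.arange(seq_len).unsqueeze(1)  # query positions
    j = torch.arange(seq_len).unsqueeze(0)  # key positions
    # True where position i may attend to position j: causal and within the window.
    return (j <= i) & (i - j <= window_size_left)

# Example: with a window of 2, token 4 attends to tokens 2, 3 and 4 only.
print(sliding_window_mask(5, 2).int())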