fix(server): fix flash neox rotary embeddings (#150)

This commit is contained in:
OlivierDehaene 2023-03-30 16:12:23 +02:00 committed by GitHub
parent 610bb1f978
commit 08b7e4a282
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 4 additions and 4 deletions

View File

@@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module):
layer_past[...] = qkv_rot[:, 1:]
# output
- attn_output = torch.empty_like(qkv[:, 0])
+ attn_output = torch.empty_like(qkv_rot[:, 0])
# flash attention
flash_attn_cuda.fwd(
- qkv[:, 0],
- qkv[:, 1],
- qkv[:, 2],
+ qkv_rot[:, 0],
+ qkv_rot[:, 1],
+ qkv_rot[:, 2],
attn_output,
cu_seqlens,
cu_seqlens,