fix(server): fix flash neox rotary embeddings (#150)
This commit is contained in:
parent
610bb1f978
commit
08b7e4a282
|
@@ -319,12 +319,12 @@ class FlashNeoxAttention(torch.nn.Module):
|
|||
layer_past[...] = qkv_rot[:, 1:]
|
||||
|
||||
# output
|
||||
attn_output = torch.empty_like(qkv[:, 0])
|
||||
attn_output = torch.empty_like(qkv_rot[:, 0])
|
||||
# flash attention
|
||||
flash_attn_cuda.fwd(
|
||||
qkv[:, 0],
|
||||
qkv[:, 1],
|
||||
qkv[:, 2],
|
||||
qkv_rot[:, 0],
|
||||
qkv_rot[:, 1],
|
||||
qkv_rot[:, 2],
|
||||
attn_output,
|
||||
cu_seqlens,
|
||||
cu_seqlens,
|
||||
|
|
Loading…
Reference in New Issue