fix EleutherAI/gpt-neox-20b does not work in tgi (#2346)
Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
This commit is contained in:
parent
82d19d7723
commit
689b1abbf6
|
@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states)
|
||||||
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
|
||||||
|
|
||||||
|
# Compute rotary embeddings on rotary_ndims
|
||||||
|
query_rot = qkv[:, 0][..., : self.rotary_dim]
|
||||||
|
query_pass = qkv[:, 0][..., self.rotary_dim :]
|
||||||
|
key_rot = qkv[:, 1][..., : self.rotary_dim]
|
||||||
|
key_pass = qkv[:, 1][..., self.rotary_dim :]
|
||||||
|
|
||||||
# Inplace rotary
|
# Inplace rotary
|
||||||
self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
|
self.rotary_emb(query_rot, key_rot, cos, sin)
|
||||||
|
qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
|
||||||
|
qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
|
||||||
|
|
||||||
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue