Fix: EleutherAI/gpt-neox-20b does not work in TGI (#2346)

Signed-off-by: Wang, Yi A <yi.a.wang@intel.com>
Authored by Wang, Yi on 2024-08-09 00:08:52 +08:00, committed by GitHub
parent 82d19d7723
commit 689b1abbf6
1 changed file with 9 additions and 1 deletion

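The change makes the flash attention path honor GPT-NeoX's partial rotary embedding: only the first rotary_dim channels of each query/key head are rotated (rotary_dim = head_size * rotary_pct, and gpt-neox-20b sets rotary_pct = 0.25), while the remaining channels must pass through unrotated. Previously the full head was handed to rotary_emb, which breaks any checkpoint with rotary_pct < 1.0.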

@@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        # Compute rotary embeddings on rotary_ndims
+        query_rot = qkv[:, 0][..., : self.rotary_dim]
+        query_pass = qkv[:, 0][..., self.rotary_dim :]
+        key_rot = qkv[:, 1][..., : self.rotary_dim]
+        key_pass = qkv[:, 1][..., self.rotary_dim :]
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
+        self.rotary_emb(query_rot, key_rot, cos, sin)
+        qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
+        qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
 
         reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
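
For context, a minimal self-contained sketch of the partial-rotary pattern the diff restores. The apply_partial_rotary and rotate_half helpers, the base-10000 frequency schedule, and the shapes are illustrative assumptions for demonstration, not TGI's actual rotary_emb kernel:

import torch

def rotate_half(x):
    # "Rotate half" formulation used by the GPT-NeoX family of models.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rotary(q, k, cos, sin, rotary_dim):
    # Rotate only the first rotary_dim channels of each head;
    # the remaining channels pass through unchanged (rotary_pct < 1.0).
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    # Reassemble full-width heads: rotated slice followed by pass-through slice.
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)

# Example with gpt-neox-20b-like shapes: 64 heads of size 96,
# rotary_pct 0.25 -> rotary_dim 24.
tokens = 4
q = torch.randn(tokens, 64, 96)  # (tokens, heads, head_size)
k = torch.randn(tokens, 64, 96)
rotary_dim = 24
inv_freq = 1.0 / (10000 ** (torch.arange(0, rotary_dim, 2).float() / rotary_dim))
freqs = torch.outer(torch.arange(tokens).float(), inv_freq)  # (tokens, rotary_dim // 2)
emb = torch.cat((freqs, freqs), dim=-1)                      # (tokens, rotary_dim)
cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]      # broadcast over heads
q, k = apply_partial_rotary(q, k, cos, sin, rotary_dim)

As in the diff, the rotated slices are written back over the full heads (here via torch.cat in the return) so that downstream consumers such as reshape_and_cache and the attention kernels see full-width heads with the pass-through channels intact.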