From 689b1abbf68cd929f41b72b06cc9e44b266fed53 Mon Sep 17 00:00:00 2001
From: "Wang, Yi"
Date: Fri, 9 Aug 2024 00:08:52 +0800
Subject: [PATCH] fix EleutherAI/gpt-neox-20b does not work in tgi (#2346)

Signed-off-by: Wang, Yi A
---
 .../models/custom_modeling/flash_neox_modeling.py | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
index b1b03ad7..67237d5c 100644
--- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py
@@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module):
         qkv = self.query_key_value(hidden_states)
         qkv = qkv.view(-1, 3, self.num_heads, self.head_size)
 
+        # Compute rotary embeddings on rotary_ndims
+        query_rot = qkv[:, 0][..., : self.rotary_dim]
+        query_pass = qkv[:, 0][..., self.rotary_dim :]
+        key_rot = qkv[:, 1][..., : self.rotary_dim]
+        key_pass = qkv[:, 1][..., self.rotary_dim :]
+
         # Inplace rotary
-        self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin)
+        self.rotary_emb(query_rot, key_rot, cos, sin)
+        qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1)
+        qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1)
 
         reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)
 
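Background on the fix: the pre-patch code passed the full qkv[:, 0] and qkv[:, 1] (all head_size channels) into rotary_emb, rotating every channel of each head. GPT-NeoX, however, applies rotary embeddings only to the first rotary_ndims = int(head_size * rotary_pct) channels; EleutherAI/gpt-neox-20b sets rotary_pct = 0.25, so only 24 of the 96 channels per head should be rotated. Rotating the pass-through channels corrupts the attention states, which is why the model produced broken output in TGI. The patch slices out the rotary portion, rotates just that slice, and writes it back.

Below is a minimal self-contained sketch of the partial rotation the patch implements, using the NeoX-style rotate-half convention. The function name apply_partial_rope and the cos/sin shapes are illustrative assumptions for this sketch, not TGI's actual rotary_emb API (which mutates its inputs in place on the GPU).

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # NeoX-style rotation: swap the two halves of the last dim, negating one.
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_partial_rope(
    query: torch.Tensor,  # [num_tokens, num_heads, head_size]
    key: torch.Tensor,    # [num_tokens, num_heads, head_size]
    cos: torch.Tensor,    # [num_tokens, 1, rotary_dim], broadcast over heads
    sin: torch.Tensor,    # [num_tokens, 1, rotary_dim]
    rotary_dim: int,      # int(head_size * rotary_pct), e.g. int(96 * 0.25) = 24
) -> tuple[torch.Tensor, torch.Tensor]:
    # Rotate only the first rotary_dim channels; the rest pass through.
    q_rot, q_pass = query[..., :rotary_dim], query[..., rotary_dim:]
    k_rot, k_pass = key[..., :rotary_dim], key[..., rotary_dim:]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    return (
        torch.cat((q_rot, q_pass), dim=-1),
        torch.cat((k_rot, k_pass), dim=-1),
    )


if __name__ == "__main__":
    # Toy shapes mirroring gpt-neox-20b's per-head geometry.
    num_tokens, num_heads, head_size, rotary_dim = 4, 64, 96, 24
    q = torch.randn(num_tokens, num_heads, head_size)
    k = torch.randn(num_tokens, num_heads, head_size)
    # Build cos/sin from inverse frequencies, duplicated across both halves
    # to match the rotate-half convention.
    positions = torch.arange(num_tokens, dtype=torch.float32)
    inv_freq = 1.0 / (
        10000.0 ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    )
    freqs = torch.outer(positions, inv_freq)  # [num_tokens, rotary_dim // 2]
    emb = torch.cat((freqs, freqs), dim=-1)   # [num_tokens, rotary_dim]
    cos, sin = emb.cos()[:, None, :], emb.sin()[:, None, :]
    q_out, k_out = apply_partial_rope(q, k, cos, sin, rotary_dim)
    # The pass-through channels must be left untouched.
    assert torch.equal(q_out[..., rotary_dim:], q[..., rotary_dim:])

One design note on the patch itself: although qkv[:, 0][..., : self.rotary_dim] is a view into qkv, the patch still writes the rotated slices back with torch.cat rather than relying on in-place mutation, presumably because the in-place rotary kernel cannot be trusted to update a strided, non-contiguous slice correctly (this rationale is inferred from the diff, not stated in the commit).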