diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index b1b03ad7..67237d5c 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -153,8 +153,16 @@ class FlashNeoxAttention(torch.nn.Module): qkv = self.query_key_value(hidden_states) qkv = qkv.view(-1, 3, self.num_heads, self.head_size) + # Compute rotary embeddings on rotary_ndims + query_rot = qkv[:, 0][..., : self.rotary_dim] + query_pass = qkv[:, 0][..., self.rotary_dim :] + key_rot = qkv[:, 1][..., : self.rotary_dim] + key_pass = qkv[:, 1][..., self.rotary_dim :] + # Inplace rotary - self.rotary_emb(qkv[:, 0], qkv[:, 1], cos, sin) + self.rotary_emb(query_rot, key_rot, cos, sin) + qkv[:, 0] = torch.cat((query_rot, query_pass), dim=-1) + qkv[:, 1] = torch.cat((key_rot, key_pass), dim=-1) reshape_and_cache(qkv[:, 1], qkv[:, 2], kv_cache[0], kv_cache[1], slots)