From c9bdaa8b734290465b8fb4db4edfc9536ff82346 Mon Sep 17 00:00:00 2001 From: OlivierDehaene Date: Tue, 28 Mar 2023 16:51:41 +0200 Subject: [PATCH] feat(server): reduce mlp and attn in one op for flash neox (#145) --- .../models/flash_neox_modeling.py | 206 +++++++++--------- 1 file changed, 102 insertions(+), 104 deletions(-) diff --git a/server/text_generation_server/models/flash_neox_modeling.py b/server/text_generation_server/models/flash_neox_modeling.py index 2e638d77..ac07aa98 100644 --- a/server/text_generation_server/models/flash_neox_modeling.py +++ b/server/text_generation_server/models/flash_neox_modeling.py @@ -1,3 +1,23 @@ +# coding=utf-8 +# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. +# +# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX +# and OPT implementations in this library. It has been modified from its +# original forms to accommodate minor architectural differences compared +# to GPT-NeoX and OPT used by the Meta AI team that trained the model. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import torch import torch.distributed @@ -16,6 +36,42 @@ import dropout_layer_norm from flash_attn.layers.rotary import RotaryEmbedding +class FastLayerNorm(nn.LayerNorm): + def forward(self, hidden_states, residual=None): + if hidden_states.shape[-1] > 6144: + if residual is not None: + hidden_states += residual + residual = hidden_states + + return super(FastLayerNorm, self).forward(hidden_states), residual + else: + ( + normed_hidden_states, + residual, + *rest, + ) = dropout_layer_norm.dropout_add_ln_fwd( + hidden_states, + residual, + self.weight, + self.bias, + None, + None, + None, + None, + 0.0, + self.eps, + 1.0, + 0, + None, + False, + False, + ) + if residual is None: + residual = hidden_states + + return normed_hidden_states, residual + + class FastLinear(nn.Linear): def __init__( self, @@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear): dtype=dtype, ) - def forward(self, input): - return super(TensorParallelColumnLinear, self).forward(input) - class TensorParallelRowLinear(FastLinear): def __init__( @@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear): in_features, out_features, process_group: torch.distributed.ProcessGroup, + reduce=True, bias=True, device=None, dtype=None, ): self.process_group = process_group self.tp_world_size = process_group.size() + self.reduce = reduce assert in_features % self.tp_world_size == 0 in_features = in_features // self.tp_world_size @@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear): def forward(self, input: torch.Tensor) -> torch.Tensor: out = super(TensorParallelRowLinear, self).forward(input) - torch.distributed.all_reduce(out, group=self.process_group) + if self.reduce: + torch.distributed.all_reduce(out, group=self.process_group) return out @@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding): class FlashNeoxAttention(torch.nn.Module): def __init__( - self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None + self, + num_heads, + hidden_size, + rotary_pct, + rotary_emb_base, + process_group=None, + reduce=True, ): super().__init__() self.num_heads = num_heads @@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module): process_group=process_group, ) self.dense = TensorParallelRowLinear( - hidden_size, - hidden_size, - process_group=process_group, + hidden_size, hidden_size, process_group=process_group, reduce=reduce ) def shuffle_qkv_dims(self): @@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module): class FlashMLP(nn.Module): - def __init__(self, act, hidden_size, intermediate_size, process_group=None): + def __init__( + self, act, hidden_size, intermediate_size, process_group=None, reduce=True + ): super().__init__() self.act = ( ACT2FN[act] @@ -330,6 +392,7 @@ class FlashMLP(nn.Module): intermediate_size, hidden_size, process_group=process_group, + reduce=reduce, ) self.process_group = process_group @@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module): ): super().__init__() self.use_parallel_residual = use_parallel_residual - self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) - self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps) + self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) + self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps) self.attention = FlashNeoxAttention( - num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group + num_heads, + hidden_size, + rotary_pct, + rotary_emb_base, + process_group, + reduce=not use_parallel_residual, ) - self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group) + self.mlp = FlashMLP( + act, + hidden_size, + intermediate_size, + process_group, + reduce=not use_parallel_residual, + ) + self.process_group = process_group def forward( self, @@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ): if self.use_parallel_residual: - # faster input layer norm - ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - None, - self.input_layernorm.weight, - self.input_layernorm.bias, - None, - None, - None, - None, - 0.0, - self.input_layernorm.eps, - 1.0, - 0, - None, - False, - False, - ) + ln1_hidden_states, _ = self.input_layernorm(hidden_states) attn_output = self.attention( ln1_hidden_states, @@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ) - # faster post attention layer norm - ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - None, - self.post_attention_layernorm.weight, - self.post_attention_layernorm.bias, - None, - None, - None, - None, - 0.0, - self.post_attention_layernorm.eps, - 1.0, - 0, - None, - False, - False, - ) + ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states) mlp_output = self.mlp(ln2_hidden_states) - return mlp_output + attn_output + hidden_states, None + intermediate = mlp_output + attn_output + + # Only reduce once and after the addition instead of once per layer + if self.process_group is not None: + torch.distributed.all_reduce(intermediate, group=self.process_group) + + return intermediate + hidden_states, None else: - # faster input layer norm - hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.input_layernorm.weight, - self.input_layernorm.bias, - None, - None, - None, - None, - 0.0, - self.input_layernorm.eps, - 1.0, - 0, - None, - False, - False, - ) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.attention( hidden_states, @@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module): cu_seqlens_q, ) - # faster post attention layer norm - hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.post_attention_layernorm.weight, - self.post_attention_layernorm.bias, - None, - None, - None, - None, - 0.0, - self.post_attention_layernorm.eps, - 1.0, - 0, - None, - False, - False, + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual ) mlp_output = self.mlp(hidden_states) @@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): for _ in range(config.num_hidden_layers) ] ) - self.final_layer_norm = nn.LayerNorm( + self.final_layer_norm = FastLayerNorm( config.hidden_size, eps=config.layer_norm_eps ) @@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): cu_seqlens_q, ) - # Faster final layer norm - hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd( - hidden_states, - residual, - self.final_layer_norm.weight, - self.final_layer_norm.bias, - None, - None, - None, - None, - 0.0, - self.final_layer_norm.eps, - 1.0, - 0, - None, - False, - False, - ) + hidden_states, _ = self.final_layer_norm(hidden_states, residual) return hidden_states, past_key_values