feat(server): reduce mlp and attn in one op for flash neox (#145)

OlivierDehaene 2023-03-28 16:51:41 +02:00 committed by GitHub
parent f000068944
commit c9bdaa8b73
1 changed file with 102 additions and 104 deletions


@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
 import torch
 import torch.distributed
@@ -16,6 +36,42 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding


+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
 class FastLinear(nn.Linear):
     def __init__(
         self,
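Note on FastLayerNorm above: the fused dropout_layer_norm.dropout_add_ln_fwd call (with dropout probability 0.0) folds the residual addition into the layer-norm kernel and also hands back the pre-norm sum so it can be reused as the next residual, which is exactly what the >6144 fallback branch does in plain PyTorch. A minimal unfused sketch of that contract, for illustration only (the helper name and the use of torch.nn.functional are not part of the diff):

import torch
import torch.nn.functional as F

def layer_norm_with_residual(hidden_states, residual, weight, bias, eps):
    # Unfused sketch of what FastLayerNorm returns: (normed, new_residual).
    if residual is not None:
        hidden_states = hidden_states + residual
    residual = hidden_states  # the pre-norm sum is reused as the next residual
    normed = F.layer_norm(
        hidden_states, (hidden_states.shape[-1],), weight, bias, eps
    )
    return normed, residual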
@@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
             dtype=dtype,
         )

-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-

 class TensorParallelRowLinear(FastLinear):
     def __init__(
@@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
@@ -88,6 +143,7 @@ class TensorParallelRowLinear(FastLinear):

     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)

         return out
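Note on the new reduce flag: a row-parallel linear shards the weight along the input dimension, so each tensor-parallel rank only computes a partial product and the all_reduce sums those partials into the full output. With reduce=False the layer now returns the rank-local partial and leaves that summation to the caller. A single-process sketch of why summing the partials recovers the full matmul (the collective is modeled as a plain Python sum; world size and shapes are made up):

import torch

torch.manual_seed(0)
world_size, in_features, out_features = 2, 8, 4

x = torch.randn(3, in_features)
w = torch.randn(out_features, in_features)
full = x @ w.t()  # what an unsharded Linear would produce (bias ignored)

# Shard the reduction dimension across simulated ranks, as a row-parallel layer does.
x_shards = x.chunk(world_size, dim=1)
w_shards = w.chunk(world_size, dim=1)
partials = [xs @ ws.t() for xs, ws in zip(x_shards, w_shards)]  # reduce=False outputs

# Summing the partials plays the role of torch.distributed.all_reduce.
assert torch.allclose(sum(partials), full, atol=1e-5)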
@@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):

 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
             process_group=process_group,
         )
         self.dense = TensorParallelRowLinear(
-            hidden_size,
-            hidden_size,
-            process_group=process_group,
+            hidden_size, hidden_size, process_group=process_group, reduce=reduce
         )

     def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):


 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
             intermediate_size,
             hidden_size,
             process_group=process_group,
+            reduce=reduce,
         )
         self.process_group = process_group
@@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
+        self.process_group = process_group

     def forward(
         self,
@@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)

             attn_output = self.attention(
                 ln1_hidden_states,
@@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )

-            # faster post attention layer norm
-            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)

             mlp_output = self.mlp(ln2_hidden_states)
-            return mlp_output + attn_output + hidden_states, None
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate + hidden_states, None
         else:
-            # faster input layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)

             hidden_states = self.attention(
                 hidden_states,
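Note on the single reduction above: with use_parallel_residual, both attention.dense and the MLP down-projection are created with reduce=False, so each returns its rank-local partial sum. Since all_reduce is an element-wise sum over ranks, reducing the two outputs separately and then adding them equals adding the partials first and reducing once, which is what the branch above does: one torch.distributed.all_reduce per layer instead of two. A single-process sketch of that identity (ranks are simulated and the collective is modeled as a plain sum):

import torch

torch.manual_seed(0)
world_size = 4
attn_partials = [torch.randn(2, 6) for _ in range(world_size)]
mlp_partials = [torch.randn(2, 6) for _ in range(world_size)]

# Two collectives: reduce the attention and MLP outputs separately, then add.
two_reduces = sum(attn_partials) + sum(mlp_partials)
# One collective: add the rank-local partials first, then reduce the sum once.
one_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))

assert torch.allclose(two_reduces, one_reduce, atol=1e-5)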
@@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )

-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )

             mlp_output = self.mlp(hidden_states)
@@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
            config.hidden_size, eps=config.layer_norm_eps
         )
@@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens_q,
             )

-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)

         return hidden_states, past_key_values