feat(server): reduce mlp and attn in one op for flash neox (#145)

This commit is contained in:
OlivierDehaene 2023-03-28 16:51:41 +02:00 committed by GitHub
parent f000068944
commit c9bdaa8b73
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 102 additions and 104 deletions

View File

@ -1,3 +1,23 @@
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.distributed
@ -16,6 +36,42 @@ import dropout_layer_norm
from flash_attn.layers.rotary import RotaryEmbedding
class FastLayerNorm(nn.LayerNorm):
def forward(self, hidden_states, residual=None):
if hidden_states.shape[-1] > 6144:
if residual is not None:
hidden_states += residual
residual = hidden_states
return super(FastLayerNorm, self).forward(hidden_states), residual
else:
(
normed_hidden_states,
residual,
*rest,
) = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.weight,
self.bias,
None,
None,
None,
None,
0.0,
self.eps,
1.0,
0,
None,
False,
False,
)
if residual is None:
residual = hidden_states
return normed_hidden_states, residual
class FastLinear(nn.Linear):
def __init__(
self,
@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
dtype=dtype,
)
def forward(self, input):
return super(TensorParallelColumnLinear, self).forward(input)
class TensorParallelRowLinear(FastLinear):
def __init__(
@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
in_features,
out_features,
process_group: torch.distributed.ProcessGroup,
reduce=True,
bias=True,
device=None,
dtype=None,
):
self.process_group = process_group
self.tp_world_size = process_group.size()
self.reduce = reduce
assert in_features % self.tp_world_size == 0
in_features = in_features // self.tp_world_size
@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear):
def forward(self, input: torch.Tensor) -> torch.Tensor:
out = super(TensorParallelRowLinear, self).forward(input)
torch.distributed.all_reduce(out, group=self.process_group)
if self.reduce:
torch.distributed.all_reduce(out, group=self.process_group)
return out
@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):
class FlashNeoxAttention(torch.nn.Module):
def __init__(
self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
self,
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group=None,
reduce=True,
):
super().__init__()
self.num_heads = num_heads
@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
process_group=process_group,
)
self.dense = TensorParallelRowLinear(
hidden_size,
hidden_size,
process_group=process_group,
hidden_size, hidden_size, process_group=process_group, reduce=reduce
)
def shuffle_qkv_dims(self):
@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):
class FlashMLP(nn.Module):
def __init__(self, act, hidden_size, intermediate_size, process_group=None):
def __init__(
self, act, hidden_size, intermediate_size, process_group=None, reduce=True
):
super().__init__()
self.act = (
ACT2FN[act]
@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
intermediate_size,
hidden_size,
process_group=process_group,
reduce=reduce,
)
self.process_group = process_group
@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
):
super().__init__()
self.use_parallel_residual = use_parallel_residual
self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
self.attention = FlashNeoxAttention(
num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
num_heads,
hidden_size,
rotary_pct,
rotary_emb_base,
process_group,
reduce=not use_parallel_residual,
)
self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
self.mlp = FlashMLP(
act,
hidden_size,
intermediate_size,
process_group,
reduce=not use_parallel_residual,
)
self.process_group = process_group
def forward(
self,
@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
):
if self.use_parallel_residual:
# faster input layer norm
ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.input_layernorm.weight,
self.input_layernorm.bias,
None,
None,
None,
None,
0.0,
self.input_layernorm.eps,
1.0,
0,
None,
False,
False,
)
ln1_hidden_states, _ = self.input_layernorm(hidden_states)
attn_output = self.attention(
ln1_hidden_states,
@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
)
# faster post attention layer norm
ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
None,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
None,
None,
None,
None,
0.0,
self.post_attention_layernorm.eps,
1.0,
0,
None,
False,
False,
)
ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
mlp_output = self.mlp(ln2_hidden_states)
return mlp_output + attn_output + hidden_states, None
intermediate = mlp_output + attn_output
# Only reduce once and after the addition instead of once per layer
if self.process_group is not None:
torch.distributed.all_reduce(intermediate, group=self.process_group)
return intermediate + hidden_states, None
else:
# faster input layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.input_layernorm.weight,
self.input_layernorm.bias,
None,
None,
None,
None,
0.0,
self.input_layernorm.eps,
1.0,
0,
None,
False,
False,
)
hidden_states, residual = self.input_layernorm(hidden_states, residual)
hidden_states = self.attention(
hidden_states,
@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
cu_seqlens_q,
)
# faster post attention layer norm
hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.post_attention_layernorm.weight,
self.post_attention_layernorm.bias,
None,
None,
None,
None,
0.0,
self.post_attention_layernorm.eps,
1.0,
0,
None,
False,
False,
hidden_states, residual = self.post_attention_layernorm(
hidden_states, residual
)
mlp_output = self.mlp(hidden_states)
@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
for _ in range(config.num_hidden_layers)
]
)
self.final_layer_norm = nn.LayerNorm(
self.final_layer_norm = FastLayerNorm(
config.hidden_size, eps=config.layer_norm_eps
)
@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
cu_seqlens_q,
)
# Faster final layer norm
hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
hidden_states,
residual,
self.final_layer_norm.weight,
self.final_layer_norm.bias,
None,
None,
None,
None,
0.0,
self.final_layer_norm.eps,
1.0,
0,
None,
False,
False,
)
hidden_states, _ = self.final_layer_norm(hidden_states, residual)
return hidden_states, past_key_values