feat(server): reduce mlp and attn in one op for flash neox (#145)
parent f000068944, commit c9bdaa8b73
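
The optimization in this commit rests on the all-reduce being a plain sum across tensor-parallel ranks: adding the attention and MLP partial outputs locally and reducing once gives the same result as reducing each projection separately and then adding. A minimal sketch of that equivalence, with made-up tensors standing in for the per-rank partial outputs (nothing below is taken from the diff itself):

import torch

# Two hypothetical tensor-parallel ranks, each holding a partial attention
# output and a partial MLP output for the same tokens.
attn_partials = [torch.randn(4, 8) for _ in range(2)]
mlp_partials = [torch.randn(4, 8) for _ in range(2)]

# Before: one reduction (sum over ranks) per projection, then the addition.
per_projection = sum(attn_partials) + sum(mlp_partials)

# After: add the partials locally on each rank, then reduce the sum once.
single_reduce = sum(a + m for a, m in zip(attn_partials, mlp_partials))

torch.testing.assert_close(per_projection, single_reduce)
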
@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.distributed
 
@@ -16,6 +36,42 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 
 
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
 class FastLinear(nn.Linear):
     def __init__(
         self,
@@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
             dtype=dtype,
         )
 
-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-
 
 class TensorParallelRowLinear(FastLinear):
     def __init__(
@@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
 
@@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
 
         return out
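
The new reduce flag matters because a row-parallel linear shards in_features across ranks, so each rank computes only a partial sum of the output; the all_reduce is what completes that sum. With reduce=False the caller receives the rank-local partial and can defer the collective. A single-process illustration of the arithmetic (shapes are made up, no distributed setup required):

import torch

x = torch.randn(2, 8)  # full input, in_features=8
w = torch.randn(6, 8)  # full weight, out_features=6
full = x @ w.t()

# Shard in_features across two "ranks": each rank multiplies its slice of the
# input by the matching slice of the weight and obtains a partial output sum.
partials = [x[:, :4] @ w[:, :4].t(), x[:, 4:] @ w[:, 4:].t()]

# The all_reduce in TensorParallelRowLinear.forward is exactly this sum; with
# reduce=False the layer returns its partial and the sum can happen later.
torch.testing.assert_close(full, partials[0] + partials[1])
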
@@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
             process_group=process_group,
         )
         self.dense = TensorParallelRowLinear(
-            hidden_size,
-            hidden_size,
-            process_group=process_group,
+            hidden_size, hidden_size, process_group=process_group, reduce=reduce
         )
 
     def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
             intermediate_size,
             hidden_size,
             process_group=process_group,
+            reduce=reduce,
         )
         self.process_group = process_group
 
@@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
         self.process_group = process_group
 
     def forward(
         self,
@@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)
 
             attn_output = self.attention(
                 ln1_hidden_states,
@@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
 
             mlp_output = self.mlp(ln2_hidden_states)
-            return mlp_output + attn_output + hidden_states, None
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate + hidden_states, None
         else:
-            # faster input layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
             hidden_states = self.attention(
                 hidden_states,
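
In the parallel-residual branch above, both the attention dense projection and the MLP output projection were constructed with reduce=False, so each returns a rank-local partial; the layer adds them and issues a single all_reduce. That halves the number of collectives on this path, roughly as follows (the layer count is a made-up example):

num_layers = 44          # hypothetical model depth
before = 2 * num_layers  # one all_reduce after attn.dense plus one after the MLP
after = 1 * num_layers   # one all_reduce on mlp_output + attn_output per layer
print(before, after)     # 88 vs 44 collectives per forward pass
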
@@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
             )
 
             mlp_output = self.mlp(hidden_states)
@@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
 
@@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens_q,
             )
 
-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
         return hidden_states, past_key_values