feat(server): reduce mlp and attn in one op for flash neox (#145)
parent f000068944 · commit c9bdaa8b73
@@ -1,3 +1,23 @@
+# coding=utf-8
+# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
+#
+# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
+# and OPT implementations in this library. It has been modified from its
+# original forms to accommodate minor architectural differences compared
+# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
 import torch
 import torch.distributed
 
@@ -16,6 +36,42 @@ import dropout_layer_norm
 from flash_attn.layers.rotary import RotaryEmbedding
 
 
+class FastLayerNorm(nn.LayerNorm):
+    def forward(self, hidden_states, residual=None):
+        if hidden_states.shape[-1] > 6144:
+            if residual is not None:
+                hidden_states += residual
+            residual = hidden_states
+
+            return super(FastLayerNorm, self).forward(hidden_states), residual
+        else:
+            (
+                normed_hidden_states,
+                residual,
+                *rest,
+            ) = dropout_layer_norm.dropout_add_ln_fwd(
+                hidden_states,
+                residual,
+                self.weight,
+                self.bias,
+                None,
+                None,
+                None,
+                None,
+                0.0,
+                self.eps,
+                1.0,
+                0,
+                None,
+                False,
+                False,
+            )
+            if residual is None:
+                residual = hidden_states
+
+            return normed_hidden_states, residual
+
+
 class FastLinear(nn.Linear):
     def __init__(
         self,
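For readers without the `dropout_layer_norm` CUDA extension, here is a minimal pure-PyTorch sketch of the contract the new `FastLayerNorm` exposes (the name `ReferenceAddLayerNorm` is illustrative; it mirrors the `> 6144` fallback branch above, not the fused kernel):

```python
import torch
from torch import nn


class ReferenceAddLayerNorm(nn.LayerNorm):
    # Optionally add the residual, keep the pre-norm sum as the new residual,
    # and return (normed_hidden_states, residual) -- the same contract as
    # FastLayerNorm, without the fused dropout_add_ln_fwd kernel.
    def forward(self, hidden_states, residual=None):
        if residual is not None:
            hidden_states = hidden_states + residual
        residual = hidden_states
        return super().forward(hidden_states), residual


norm = ReferenceAddLayerNorm(1024, eps=1e-5)
normed, residual = norm(torch.randn(2, 1024), residual=torch.randn(2, 1024))
```

Returning the pre-norm sum lets the next call site fuse its own residual addition into the same kernel instead of materialising it separately.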
@@ -59,9 +115,6 @@ class TensorParallelColumnLinear(FastLinear):
             dtype=dtype,
         )
 
-    def forward(self, input):
-        return super(TensorParallelColumnLinear, self).forward(input)
-
 
 class TensorParallelRowLinear(FastLinear):
     def __init__(
@@ -69,12 +122,14 @@ class TensorParallelRowLinear(FastLinear):
         in_features,
         out_features,
         process_group: torch.distributed.ProcessGroup,
+        reduce=True,
         bias=True,
         device=None,
         dtype=None,
     ):
         self.process_group = process_group
         self.tp_world_size = process_group.size()
+        self.reduce = reduce
         assert in_features % self.tp_world_size == 0
         in_features = in_features // self.tp_world_size
 
@@ -88,7 +143,8 @@ class TensorParallelRowLinear(FastLinear):
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
         out = super(TensorParallelRowLinear, self).forward(input)
-        torch.distributed.all_reduce(out, group=self.process_group)
+        if self.reduce:
+            torch.distributed.all_reduce(out, group=self.process_group)
 
         return out
 
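The new `reduce` flag lets a caller skip the per-projection all-reduce and perform it later. A small sanity check of why deferring is safe (a hedged sketch, not part of the repo: it assumes a process group launched with e.g. `torchrun --nproc_per_node=2` and uses the `gloo` backend for simplicity): all-reduce with SUM is an elementwise sum over ranks, so adding the rank-local partials first and reducing once gives the same result as reducing each partial output separately.

```python
import torch
import torch.distributed as dist


def main():
    dist.init_process_group("gloo")
    torch.manual_seed(dist.get_rank())  # different partial outputs per rank
    attn_partial = torch.randn(4, 8)
    mlp_partial = torch.randn(4, 8)

    # Old path: one all-reduce per sub-module output.
    a, b = attn_partial.clone(), mlp_partial.clone()
    dist.all_reduce(a)
    dist.all_reduce(b)
    two_reduces = a + b

    # New path: add the rank-local partials first, reduce the sum once.
    fused = attn_partial + mlp_partial
    dist.all_reduce(fused)

    torch.testing.assert_close(two_reduces, fused)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```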
@@ -196,7 +252,13 @@ class PositionRotaryEmbedding(RotaryEmbedding):
 
 class FlashNeoxAttention(torch.nn.Module):
     def __init__(
-        self, num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group=None
+        self,
+        num_heads,
+        hidden_size,
+        rotary_pct,
+        rotary_emb_base,
+        process_group=None,
+        reduce=True,
     ):
         super().__init__()
         self.num_heads = num_heads
@@ -218,9 +280,7 @@ class FlashNeoxAttention(torch.nn.Module):
             process_group=process_group,
         )
         self.dense = TensorParallelRowLinear(
-            hidden_size,
-            hidden_size,
-            process_group=process_group,
+            hidden_size, hidden_size, process_group=process_group, reduce=reduce
         )
 
     def shuffle_qkv_dims(self):
@@ -309,7 +369,9 @@ class FlashNeoxAttention(torch.nn.Module):
 
 
 class FlashMLP(nn.Module):
-    def __init__(self, act, hidden_size, intermediate_size, process_group=None):
+    def __init__(
+        self, act, hidden_size, intermediate_size, process_group=None, reduce=True
+    ):
         super().__init__()
         self.act = (
             ACT2FN[act]
@@ -330,6 +392,7 @@ class FlashMLP(nn.Module):
             intermediate_size,
             hidden_size,
             process_group=process_group,
+            reduce=reduce,
         )
         self.process_group = process_group
 
@@ -355,12 +418,24 @@ class FlashNeoXLayer(nn.Module):
     ):
         super().__init__()
         self.use_parallel_residual = use_parallel_residual
-        self.input_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
-        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=layer_norm_eps)
+        self.input_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
+        self.post_attention_layernorm = FastLayerNorm(hidden_size, eps=layer_norm_eps)
         self.attention = FlashNeoxAttention(
-            num_heads, hidden_size, rotary_pct, rotary_emb_base, process_group
+            num_heads,
+            hidden_size,
+            rotary_pct,
+            rotary_emb_base,
+            process_group,
+            reduce=not use_parallel_residual,
         )
-        self.mlp = FlashMLP(act, hidden_size, intermediate_size, process_group)
+        self.mlp = FlashMLP(
+            act,
+            hidden_size,
+            intermediate_size,
+            process_group,
+            reduce=not use_parallel_residual,
+        )
+        self.process_group = process_group
 
     def forward(
         self,
@@ -375,24 +450,7 @@ class FlashNeoXLayer(nn.Module):
         cu_seqlens_q,
     ):
         if self.use_parallel_residual:
-            # faster input layer norm
-            ln1_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln1_hidden_states, _ = self.input_layernorm(hidden_states)
 
             attn_output = self.attention(
                 ln1_hidden_states,
@@ -405,46 +463,18 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            ln2_hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                None,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            ln2_hidden_states, _ = self.post_attention_layernorm(hidden_states)
 
             mlp_output = self.mlp(ln2_hidden_states)
-            return mlp_output + attn_output + hidden_states, None
+            intermediate = mlp_output + attn_output
+
+            # Only reduce once and after the addition instead of once per layer
+            if self.process_group is not None:
+                torch.distributed.all_reduce(intermediate, group=self.process_group)
+
+            return intermediate + hidden_states, None
         else:
-            # faster input layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.input_layernorm.weight,
-                self.input_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.input_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
-            )
+            hidden_states, residual = self.input_layernorm(hidden_states, residual)
 
             hidden_states = self.attention(
                 hidden_states,
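Condensed, the new parallel-residual path amounts to the sketch below (a simplification, not the repo's code: it omits the flash-attention arguments and KV-cache plumbing, and `ln1`, `ln2`, `attention`, `mlp` are hypothetical stand-ins whose output projections were built with `reduce=False`, so their outputs are rank-local partial sums):

```python
import torch.distributed as dist


def parallel_residual_forward(
    hidden_states, ln1, ln2, attention, mlp, process_group=None
):
    # Each fused layer norm returns (normed, residual); only normed is needed here.
    attn_output = attention(ln1(hidden_states)[0])
    mlp_output = mlp(ln2(hidden_states)[0])

    # attn_output and mlp_output are partial sums on each tensor-parallel rank:
    # add them locally, then issue a single all-reduce per layer instead of two.
    intermediate = mlp_output + attn_output
    if process_group is not None:
        dist.all_reduce(intermediate, group=process_group)

    return intermediate + hidden_states
```

In the non-parallel-residual branch the projections are built with `reduce=not use_parallel_residual`, i.e. the per-projection reduces stay enabled, so the reduction pattern there is unchanged.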
@@ -457,23 +487,8 @@ class FlashNeoXLayer(nn.Module):
                 cu_seqlens_q,
             )
 
-            # faster post attention layer norm
-            hidden_states, residual, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-                hidden_states,
-                residual,
-                self.post_attention_layernorm.weight,
-                self.post_attention_layernorm.bias,
-                None,
-                None,
-                None,
-                None,
-                0.0,
-                self.post_attention_layernorm.eps,
-                1.0,
-                0,
-                None,
-                False,
-                False,
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
             )
 
             mlp_output = self.mlp(hidden_states)
@@ -523,7 +538,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 for _ in range(config.num_hidden_layers)
             ]
         )
-        self.final_layer_norm = nn.LayerNorm(
+        self.final_layer_norm = FastLayerNorm(
             config.hidden_size, eps=config.layer_norm_eps
         )
 
@@ -603,24 +618,7 @@ class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel):
                 cu_seqlens_q,
             )
 
-        # Faster final layer norm
-        hidden_states, *rest = dropout_layer_norm.dropout_add_ln_fwd(
-            hidden_states,
-            residual,
-            self.final_layer_norm.weight,
-            self.final_layer_norm.bias,
-            None,
-            None,
-            None,
-            None,
-            0.0,
-            self.final_layer_norm.eps,
-            1.0,
-            0,
-            None,
-            False,
-            False,
-        )
+        hidden_states, _ = self.final_layer_norm(hidden_states, residual)
 
         return hidden_states, past_key_values
 