Fixup residual, initial block attention config

parent 4ed551abba
commit 5d2b93ba42

@@ -19,6 +19,7 @@
 # limitations under the License.
 
 from typing import List, Optional, Tuple
+from dataclasses import dataclass
 
 import torch
 import torch.distributed
@@ -50,6 +51,14 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
 
+@dataclass
+class BlockSparseAttentionConfig:
+    block_size: int
+    homo_head_pattern: bool
+    num_local_blocks: int
+    vert_stride: int
+
+
 def load_attention(config, prefix, weights):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
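
Note: the fields of this dataclass mirror the `blocksparse_*` entries of the Phi-3-small config, which the attention constructor reads in a later hunk. A minimal, illustrative instantiation; the values here are assumptions for the sketch, not taken from this diff:

    # Illustrative values only -- the real ones come from config.blocksparse_*.
    sparse_cfg = BlockSparseAttentionConfig(
        block_size=64,
        homo_head_pattern=False,
        num_local_blocks=16,
        vert_stride=8,
    )
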
@@ -69,6 +78,7 @@ class FlashPhi3SmallAttention(torch.nn.Module):
         self,
         prefix: str,
         config,
+        layer_id: int,
         weights,
     ):
         super().__init__()
@@ -83,6 +93,9 @@ class FlashPhi3SmallAttention(torch.nn.Module):
             device=weights.device,
         )
 
-        self.softmax_scale = self.head_size**-0.5
+        if hasattr(config, "mup_use_scaling") and config.mup_use_scaling:
+            self.softmax_scale = self.head_size / config.mup_attn_multiplier
+        else:
+            self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
@@ -102,11 +115,25 @@ class FlashPhi3SmallAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        is_dense = getattr(config, "dense_attention_every_n_layers", False) and (
+            (layer_id + 1) % config.dense_attention_every_n_layers == 0
+        )
+
+        if is_dense:
+            self.blocksparse_config = None
+        else:
+            self.blocksparse_config = BlockSparseAttentionConfig(
+                block_size=config.blocksparse_block_size,
+                homo_head_pattern=config.blocksparse_homo_head_pattern,
+                num_local_blocks=config.blocksparse_num_local_blocks,
+                vert_stride=config.blocksparse_vert_stride,
+            )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
             weights=weights,
-            bias=False,
+            bias=True,
         )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
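
Note: the `is_dense` test above alternates dense and blocksparse layers based on `dense_attention_every_n_layers`. A small sketch of the resulting pattern, assuming a value of 2 (an assumed value, not taken from this diff):

    # With dense_attention_every_n_layers = 2 (assumed), the 0-based layers
    # 1, 3, 5, ... fall back to dense attention; the others stay blocksparse.
    dense_every_n = 2
    for layer_id in range(6):
        is_dense = bool(dense_every_n) and (layer_id + 1) % dense_every_n == 0
        print(layer_id, "dense" if is_dense else "blocksparse")
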
@@ -246,10 +273,13 @@ class Phi3SmallMLP(nn.Module):
 
 
 class FlashPhi3SmallLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, layer_id: int, weights):
         super().__init__()
         self.self_attn = FlashPhi3SmallAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            layer_id=layer_id,
+            weights=weights,
         )
         self.mlp = Phi3SmallMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
@@ -267,7 +297,6 @@ class FlashPhi3SmallLayer(nn.Module):
     def forward(
         self,
         hidden_states,
-        residual,
         cos,
         sin,
         cu_seqlen_prefill,
@@ -277,7 +306,8 @@ class FlashPhi3SmallLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        residual = hidden_states
+        normed_hidden_states, res = self.input_layernorm(hidden_states, None)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -294,12 +324,13 @@ class FlashPhi3SmallLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, residual
         )
 
         mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = attn_res + mlp_output
 
-        return mlp_output, attn_res
+        return mlp_output
 
 
 class FlashPhi3SmallModel(torch.nn.Module):
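
Note: together with the earlier hunks, each layer now keeps its own residual and returns a single tensor, so the model loop no longer threads a residual through. A rough, self-contained sketch of the per-layer flow after this change; names are stand-ins, and the real norms are the repo's fused layernorm modules:

    import torch

    def fused_layer_norm(x, residual):
        # Stand-in for the fused norm: returns the normalized tensor plus the
        # pre-norm sum so callers can reuse it as the next residual.
        pre = x if residual is None else x + residual
        return torch.nn.functional.layer_norm(pre, pre.shape[-1:]), pre

    def layer_forward(hidden_states, self_attn, mlp):
        residual = hidden_states                      # saved before the first norm
        normed, _ = fused_layer_norm(hidden_states, None)
        attn_output = self_attn(normed)
        normed_attn, attn_res = fused_layer_norm(attn_output, residual)
        return attn_res + mlp(normed_attn)            # residual add happens inside the layer
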
@@ -318,12 +349,13 @@ class FlashPhi3SmallModel(torch.nn.Module):
                         else f"{prefix}.model.layers.{layer_id}"
                     ),
                     config=config,
+                    layer_id=layer_id,
                     weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastLayerNorm.load(
+        self.norm = nn.LayerNorm.load(
             prefix=(
                 "model.final_layernorm"
                 if not prefix
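
Note: swapping FastLayerNorm for the nn.LayerNorm loader (presumably the `load` helper the repo attaches to LayerNorm; an assumption here) also changes the call convention that the last hunk adapts to: the fused norm returned a pair, as the dropped `hidden_states, _ = self.norm(...)` line shows, while a plain LayerNorm returns a single tensor. Roughly:

    import torch

    x = torch.randn(2, 8)
    ln = torch.nn.LayerNorm(8)
    out = ln(x)  # one tensor out, hence "hidden_states = self.norm(hidden_states)" below
    # The previous fused-norm call returned a pair instead, roughly:
    #     normed, pre_norm_sum = fast_layer_norm(x, residual)
    # which is why the tuple unpacking is dropped in the final hunk.
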
@@ -360,11 +392,9 @@ class FlashPhi3SmallModel(torch.nn.Module):
             position_ids, max_s, hidden_states.dtype
         )
 
-        residual = None
         for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 hidden_states,
-                residual,
                 cos,
                 sin,
                 cu_seqlen_prefill,
@@ -375,7 +405,7 @@ class FlashPhi3SmallModel(torch.nn.Module):
                 max_s,
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)
 
         return hidden_states
 