Fixup residual, initial block attention config

parent 4ed551abba
commit 5d2b93ba42
@@ -19,6 +19,7 @@
 # limitations under the License.
 
 from typing import List, Optional, Tuple
+from dataclasses import dataclass
 
 import torch
 import torch.distributed
@@ -50,6 +51,14 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
 
 
+@dataclass
+class BlockSparseAttentionConfig:
+    block_size: int
+    homo_head_pattern: bool
+    num_local_blocks: int
+    vert_stride: int
+
+
 def load_attention(config, prefix, weights):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
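The new dataclass simply bundles the four block-sparse attention hyperparameters so they can be passed around as a single object. A minimal sketch of constructing it directly, with illustrative values rather than anything read from a real Phi-3-small config:

    # Illustrative values only; the real ones come from the config.blocksparse_* fields.
    sparse_cfg = BlockSparseAttentionConfig(
        block_size=64,            # size of each attention block
        homo_head_pattern=False,  # True would give every head the same sparsity pattern
        num_local_blocks=16,      # local blocks kept around each query block
        vert_stride=8,            # stride between the vertical/global blocks
    )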
@@ -69,6 +78,7 @@ class FlashPhi3SmallAttention(torch.nn.Module):
         self,
         prefix: str,
         config,
+        layer_id: int,
         weights,
     ):
         super().__init__()
@@ -83,6 +93,9 @@ class FlashPhi3SmallAttention(torch.nn.Module):
             device=weights.device,
         )
 
-        self.softmax_scale = self.head_size**-0.5
+        if hasattr(config, "mup_use_scaling") and config.mup_use_scaling:
+            self.softmax_scale = self.head_size / config.mup_attn_multiplier
+        else:
+            self.softmax_scale = self.head_size**-0.5
 
         if self.num_heads % weights.process_group.size() != 0:
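With muP scaling enabled, the softmax scale is derived from the config instead of the usual 1/sqrt(head_size). A quick numeric sketch of what each branch computes, using assumed example values (head_size=128, mup_attn_multiplier=8.0), not values from any real checkpoint:

    head_size = 128
    mup_attn_multiplier = 8.0                      # assumed example config value

    mup_scale = head_size / mup_attn_multiplier    # 16.0 when mup_use_scaling is set
    default_scale = head_size ** -0.5              # ~0.088 otherwise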
@@ -102,11 +115,25 @@ class FlashPhi3SmallAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights)
 
+        is_dense = getattr(config, "dense_attention_every_n_layers", False) and (
+            (layer_id + 1) % config.dense_attention_every_n_layers == 0
+        )
+
+        if is_dense:
+            self.blocksparse_config = None
+        else:
+            self.blocksparse_config = BlockSparseAttentionConfig(
+                block_size=config.blocksparse_block_size,
+                homo_head_pattern=config.blocksparse_homo_head_pattern,
+                num_local_blocks=config.blocksparse_num_local_blocks,
+                vert_stride=config.blocksparse_vert_stride,
+            )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
             weights=weights,
-            bias=False,
+            bias=True,
         )
 
         self.num_groups = self.num_heads // self.num_key_value_heads
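Whether a layer keeps a block-sparse config is decided purely by its index: with dense_attention_every_n_layers = n, every n-th layer (counting from 1) is dense and gets blocksparse_config = None, while all other layers build a BlockSparseAttentionConfig. A small sketch of the selection logic, assuming an illustrative 8-layer model with n = 2:

    # Illustrative: 8 layers, dense attention on every 2nd layer
    dense_attention_every_n_layers = 2
    num_hidden_layers = 8

    for layer_id in range(num_hidden_layers):
        is_dense = (layer_id + 1) % dense_attention_every_n_layers == 0
        print(layer_id, "dense" if is_dense else "block-sparse")
    # layer_ids 1, 3, 5, 7 -> dense; layer_ids 0, 2, 4, 6 -> block-sparse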
@@ -246,10 +273,13 @@ class Phi3SmallMLP(nn.Module):
 
 
 class FlashPhi3SmallLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, layer_id: int, weights):
         super().__init__()
         self.self_attn = FlashPhi3SmallAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            layer_id=layer_id,
+            weights=weights,
         )
         self.mlp = Phi3SmallMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
 
@@ -267,7 +297,6 @@ class FlashPhi3SmallLayer(nn.Module):
     def forward(
         self,
         hidden_states,
-        residual,
         cos,
         sin,
         cu_seqlen_prefill,
@@ -277,7 +306,8 @@ class FlashPhi3SmallLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        residual = hidden_states
+        normed_hidden_states, res = self.input_layernorm(hidden_states, None)
 
         # Self Attention
         attn_output = self.self_attn(
@@ -294,12 +324,13 @@ class FlashPhi3SmallLayer(nn.Module):
 
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, residual
         )
 
         mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = attn_res + mlp_output
 
-        return mlp_output, attn_res
+        return mlp_output
 
 
 class FlashPhi3SmallModel(torch.nn.Module):
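The residual fixup means the layer no longer receives a residual from its caller; it snapshots its own input and applies both skip connections itself. A condensed sketch of the resulting forward flow (argument lists trimmed, and assuming the fused layernorm returns both the normed output and the residual-added sum, as it is used elsewhere in this file):

    # Condensed view of FlashPhi3SmallLayer.forward after this change (not verbatim)
    residual = hidden_states                                  # snapshot the layer input
    normed_hidden_states, _ = self.input_layernorm(hidden_states, None)
    attn_output = self.self_attn(normed_hidden_states, ...)   # attention on the normed input
    normed_attn_res_output, attn_res = self.post_attention_layernorm(
        attn_output, residual                                 # attn_res = attn_output + residual (assumed)
    )
    mlp_output = self.mlp(normed_attn_res_output)
    return attn_res + mlp_output                              # second skip connection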
@@ -318,12 +349,13 @@ class FlashPhi3SmallModel(torch.nn.Module):
                         else f"{prefix}.model.layers.{layer_id}"
                     ),
                     config=config,
+                    layer_id=layer_id,
                     weights=weights,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
-        self.norm = FastLayerNorm.load(
+        self.norm = nn.LayerNorm.load(
             prefix=(
                 "model.final_layernorm"
                 if not prefix
@@ -360,11 +392,9 @@ class FlashPhi3SmallModel(torch.nn.Module):
             position_ids, max_s, hidden_states.dtype
         )
 
-        residual = None
         for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
                 hidden_states,
-                residual,
                 cos,
                 sin,
                 cu_seqlen_prefill,
@@ -375,7 +405,7 @@ class FlashPhi3SmallModel(torch.nn.Module):
                 max_s,
             )
 
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)
 
         return hidden_states
 
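With the layers owning their residuals, the model-level loop simplifies accordingly: layer() now returns a single tensor, and the final norm is called on hidden_states alone, with a single return value. A condensed sketch of the loop after this change:

    # Condensed view of FlashPhi3SmallModel.forward after this change (not verbatim)
    for i, layer in enumerate(self.layers):
        hidden_states = layer(
            hidden_states, cos, sin, cu_seqlen_prefill, ...   # remaining batch metadata trimmed
        )
    hidden_states = self.norm(hidden_states)
    return hidden_states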