Fixup residual, initial block attention config

Daniël de Kok 2024-06-13 10:38:56 +02:00
parent 4ed551abba
commit 5d2b93ba42
1 changed file with 43 additions and 13 deletions


@@ -19,6 +19,7 @@
 # limitations under the License.
 from typing import List, Optional, Tuple
+from dataclasses import dataclass

 import torch
 import torch.distributed
@@ -50,6 +51,14 @@ if SYSTEM == "rocm":
         raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")


+@dataclass
+class BlockSparseAttentionConfig:
+    block_size: int
+    homo_head_pattern: bool
+    num_local_blocks: int
+    vert_stride: int
+
+
 def load_attention(config, prefix, weights):
     # Only defined in granite.
     bias = getattr(config, "attention_bias", False)
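Note: the fields above mirror the blocksparse_* attributes of the Phi-3-small configuration. As one illustration of what they encode, a sketch assuming the usual local-plus-strided interpretation (each query block sees the most recent num_local_blocks blocks plus every vert_stride-th earlier block); this is not code from the commit, and the per-head variation used when homo_head_pattern is False is not modelled:

from dataclasses import dataclass

import torch


@dataclass
class BlockSparseAttentionConfig:
    block_size: int
    homo_head_pattern: bool
    num_local_blocks: int
    vert_stride: int


def block_visibility(config: BlockSparseAttentionConfig, seqlen: int) -> torch.Tensor:
    """Boolean (num_blocks, num_blocks) mask: which key blocks each query block sees."""
    num_blocks = (seqlen + config.block_size - 1) // config.block_size
    q = torch.arange(num_blocks)[:, None]  # query block index
    k = torch.arange(num_blocks)[None, :]  # key block index
    causal = k <= q
    local = (q - k) < config.num_local_blocks
    # Every vert_stride-th key block stays visible to all later query blocks.
    strided = (k + 1) % config.vert_stride == 0
    return causal & (local | strided)


cfg = BlockSparseAttentionConfig(
    block_size=64, homo_head_pattern=True, num_local_blocks=16, vert_stride=8
)
print(block_visibility(cfg, seqlen=4096).shape)  # torch.Size([64, 64])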
@@ -69,6 +78,7 @@ class FlashPhi3SmallAttention(torch.nn.Module):
         self,
         prefix: str,
         config,
+        layer_id: int,
         weights,
     ):
         super().__init__()
@@ -83,6 +93,9 @@ class FlashPhi3SmallAttention(torch.nn.Module):
             device=weights.device,
         )

-        self.softmax_scale = self.head_size**-0.5
+        if hasattr(config, "mup_use_scaling") and config.mup_use_scaling:
+            self.softmax_scale = self.head_size / config.mup_attn_multiplier
+        else:
+            self.softmax_scale = self.head_size**-0.5

         if self.num_heads % weights.process_group.size() != 0:
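For reference, with illustrative numbers (a head size of 128 and mup_attn_multiplier of 1.0, neither taken from a real checkpoint), the two branches above evaluate to:

head_size = 128

# Default scaled-dot-product attention scaling.
default_scale = head_size**-0.5  # ~0.0884

# muP-style value selected when config.mup_use_scaling is set,
# assuming an illustrative mup_attn_multiplier of 1.0.
mup_attn_multiplier = 1.0
mup_scale = head_size / mup_attn_multiplier  # 128.0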
@@ -102,11 +115,25 @@ class FlashPhi3SmallAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights)

+        is_dense = getattr(config, "dense_attention_every_n_layers", False) and (
+            (layer_id + 1) % config.dense_attention_every_n_layers == 0
+        )
+        if is_dense:
+            self.blocksparse_config = None
+        else:
+            self.blocksparse_config = BlockSparseAttentionConfig(
+                block_size=config.blocksparse_block_size,
+                homo_head_pattern=config.blocksparse_homo_head_pattern,
+                num_local_blocks=config.blocksparse_num_local_blocks,
+                vert_stride=config.blocksparse_vert_stride,
+            )
+
         self.o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.dense",
             weights=weights,
-            bias=False,
+            bias=True,
         )

         self.num_groups = self.num_heads // self.num_key_value_heads
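With a hypothetical dense_attention_every_n_layers = 2, the is_dense test above makes every second layer use dense attention and leaves the rest block-sparse:

dense_attention_every_n_layers = 2  # hypothetical config value
num_hidden_layers = 8

for layer_id in range(num_hidden_layers):
    is_dense = dense_attention_every_n_layers and (
        (layer_id + 1) % dense_attention_every_n_layers == 0
    )
    print(layer_id, "dense" if is_dense else "blocksparse")
# 0 blocksparse, 1 dense, 2 blocksparse, 3 dense, ...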
@@ -246,10 +273,13 @@ class Phi3SmallMLP(nn.Module):
 class FlashPhi3SmallLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, layer_id: int, weights):
         super().__init__()
         self.self_attn = FlashPhi3SmallAttention(
-            prefix=f"{prefix}.self_attn", config=config, weights=weights
+            prefix=f"{prefix}.self_attn",
+            config=config,
+            layer_id=layer_id,
+            weights=weights,
         )
         self.mlp = Phi3SmallMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -267,7 +297,6 @@ class FlashPhi3SmallLayer(nn.Module):
     def forward(
         self,
         hidden_states,
-        residual,
         cos,
         sin,
         cu_seqlen_prefill,
@@ -277,7 +306,8 @@ class FlashPhi3SmallLayer(nn.Module):
         input_lengths,
         max_s,
     ):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        residual = hidden_states
+        normed_hidden_states, res = self.input_layernorm(hidden_states, None)

         # Self Attention
         attn_output = self.self_attn(
@@ -294,12 +324,13 @@ class FlashPhi3SmallLayer(nn.Module):
         # faster post attention rms norm
         normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, residual
         )

         mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = attn_res + mlp_output

-        return mlp_output, attn_res
+        return mlp_output


 class FlashPhi3SmallModel(torch.nn.Module):
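Taken together, the forward changes in this layer capture the residual once from the incoming hidden states and return a single tensor instead of a (hidden_states, residual) pair. A stripped-down sketch of the resulting data flow, with plain nn.LayerNorm and nn.Linear standing in for the fused norms, attention, and MLP of the real layer (in TGI the fused norm itself returns both the normed value and the summed residual; here the sum is explicit):

import torch
from torch import nn


class ResidualSketch(nn.Module):
    """Simplified data flow of FlashPhi3SmallLayer.forward after this commit."""

    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        # Stand-ins for the real attention and MLP blocks.
        self.self_attn = nn.Linear(hidden_size, hidden_size)
        self.mlp = nn.Linear(hidden_size, hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        attn_output = self.self_attn(self.input_layernorm(hidden_states))
        # Residual added back after attention, then normed for the MLP.
        attn_res = attn_output + residual
        mlp_output = self.mlp(self.post_attention_layernorm(attn_res))
        return attn_res + mlp_output


layer = ResidualSketch(16)
print(layer(torch.randn(2, 16)).shape)  # torch.Size([2, 16])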
@@ -318,12 +349,13 @@ class FlashPhi3SmallModel(torch.nn.Module):
                        else f"{prefix}.model.layers.{layer_id}"
                    ),
                    config=config,
+                    layer_id=layer_id,
                    weights=weights,
                )
                for layer_id in range(config.num_hidden_layers)
            ]
        )
-        self.norm = FastLayerNorm.load(
+        self.norm = nn.LayerNorm.load(
            prefix=(
                "model.final_layernorm"
                if not prefix
@@ -360,11 +392,9 @@ class FlashPhi3SmallModel(torch.nn.Module):
            position_ids, max_s, hidden_states.dtype
        )

-        residual = None
        for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
                hidden_states,
-                residual,
                cos,
                sin,
                cu_seqlen_prefill,
@@ -375,7 +405,7 @@ class FlashPhi3SmallModel(torch.nn.Module):
                max_s,
            )

-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)

        return hidden_states
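At the model level the same contract change shows up in the last two hunks: each layer now takes and returns a single hidden-states tensor, and the final norm no longer receives a residual. Roughly, with stand-in modules rather than the real layers:

import torch
from torch import nn

# Stand-in layers that, like the updated FlashPhi3SmallLayer, take and return
# a single hidden-states tensor instead of a (hidden_states, residual) pair.
layers = nn.ModuleList(nn.Linear(16, 16) for _ in range(4))
norm = nn.LayerNorm(16)

hidden_states = torch.randn(2, 16)
for layer in layers:
    hidden_states = layer(hidden_states)
hidden_states = norm(hidden_states)
print(hidden_states.shape)  # torch.Size([2, 16])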