Fixup residual, initial block attention config

Daniël de Kok 2024-06-13 10:38:56 +02:00
parent 4ed551abba
commit 5d2b93ba42
1 changed file with 43 additions and 13 deletions


@@ -19,6 +19,7 @@
# limitations under the License.
from typing import List, Optional, Tuple
+from dataclasses import dataclass
import torch
import torch.distributed
@@ -50,6 +51,14 @@ if SYSTEM == "rocm":
raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}")
+@dataclass
+class BlockSparseAttentionConfig:
+    block_size: int
+    homo_head_pattern: bool
+    num_local_blocks: int
+    vert_stride: int
def load_attention(config, prefix, weights):
# Only defined in granite.
bias = getattr(config, "attention_bias", False)
@@ -69,6 +78,7 @@ class FlashPhi3SmallAttention(torch.nn.Module):
self,
prefix: str,
config,
+        layer_id: int,
weights,
):
super().__init__()
@@ -83,7 +93,10 @@ class FlashPhi3SmallAttention(torch.nn.Module):
device=weights.device,
)
-        self.softmax_scale = self.head_size**-0.5
+        if hasattr(config, "mup_use_scaling") and config.mup_use_scaling:
+            self.softmax_scale = self.head_size / config.mup_attn_multiplier
+        else:
+            self.softmax_scale = self.head_size**-0.5
if self.num_heads % weights.process_group.size() != 0:
raise ValueError(
@@ -102,11 +115,25 @@ class FlashPhi3SmallAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
+        is_dense = getattr(config, "dense_attention_every_n_layers", False) and (
+            (layer_id + 1) % config.dense_attention_every_n_layers == 0
+        )
+        if is_dense:
+            self.blocksparse_config = None
+        else:
+            self.blocksparse_config = BlockSparseAttentionConfig(
+                block_size=config.blocksparse_block_size,
+                homo_head_pattern=config.blocksparse_homo_head_pattern,
+                num_local_blocks=config.blocksparse_num_local_blocks,
+                vert_stride=config.blocksparse_vert_stride,
+            )
self.o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.dense",
weights=weights,
-            bias=False,
+            bias=True,
)
self.num_groups = self.num_heads // self.num_key_value_heads
@@ -246,10 +273,13 @@ class Phi3SmallMLP(nn.Module):
class FlashPhi3SmallLayer(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, layer_id: int, weights):
super().__init__()
self.self_attn = FlashPhi3SmallAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights
prefix=f"{prefix}.self_attn",
config=config,
layer_id=layer_id,
weights=weights,
)
self.mlp = Phi3SmallMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
@@ -267,7 +297,6 @@ class FlashPhi3SmallLayer(nn.Module):
def forward(
self,
hidden_states,
-        residual,
cos,
sin,
cu_seqlen_prefill,
@@ -277,7 +306,8 @@ class FlashPhi3SmallLayer(nn.Module):
input_lengths,
max_s,
):
-        normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
+        residual = hidden_states
+        normed_hidden_states, res = self.input_layernorm(hidden_states, None)
# Self Attention
attn_output = self.self_attn(
@@ -294,12 +324,13 @@ class FlashPhi3SmallLayer(nn.Module):
# faster post attention rms norm
normed_attn_res_output, attn_res = self.post_attention_layernorm(
-            attn_output, res
+            attn_output, residual
)
mlp_output = self.mlp(normed_attn_res_output)
+        mlp_output = attn_res + mlp_output
-        return mlp_output, attn_res
+        return mlp_output
class FlashPhi3SmallModel(torch.nn.Module):
@@ -318,12 +349,13 @@ class FlashPhi3SmallModel(torch.nn.Module):
else f"{prefix}.model.layers.{layer_id}"
),
config=config,
+                    layer_id=layer_id,
weights=weights,
)
for layer_id in range(config.num_hidden_layers)
]
)
-        self.norm = FastLayerNorm.load(
+        self.norm = nn.LayerNorm.load(
prefix=(
"model.final_layernorm"
if not prefix
@@ -360,11 +392,9 @@ class FlashPhi3SmallModel(torch.nn.Module):
position_ids, max_s, hidden_states.dtype
)
-        residual = None
for i, layer in enumerate(self.layers):
-            hidden_states, residual = layer(
+            hidden_states = layer(
hidden_states,
-                residual,
cos,
sin,
cu_seqlen_prefill,
@@ -375,7 +405,7 @@ class FlashPhi3SmallLayer(nn.Module):
max_s,
)
-        hidden_states, _ = self.norm(hidden_states, residual)
+        hidden_states = self.norm(hidden_states)
return hidden_states
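
Note: the per-layer attention selection introduced in this diff can be summarized standalone. This is a minimal sketch, not the commit's code as-is: the commit performs this logic inline in FlashPhi3SmallAttention.__init__, and the helper name build_blocksparse_config is hypothetical; the dataclass and the config attribute names (dense_attention_every_n_layers, blocksparse_*) are taken from the diff above.

from dataclasses import dataclass
from typing import Optional


@dataclass
class BlockSparseAttentionConfig:
    block_size: int
    homo_head_pattern: bool
    num_local_blocks: int
    vert_stride: int


def build_blocksparse_config(config, layer_id: int) -> Optional[BlockSparseAttentionConfig]:
    # Hypothetical helper mirroring the per-layer logic added in the diff:
    # every dense_attention_every_n_layers-th layer uses dense attention
    # (no block-sparse config), all other layers get a block-sparse config
    # built from the model config.
    is_dense = getattr(config, "dense_attention_every_n_layers", False) and (
        (layer_id + 1) % config.dense_attention_every_n_layers == 0
    )
    if is_dense:
        return None
    return BlockSparseAttentionConfig(
        block_size=config.blocksparse_block_size,
        homo_head_pattern=config.blocksparse_homo_head_pattern,
        num_local_blocks=config.blocksparse_num_local_blocks,
        vert_stride=config.blocksparse_vert_stride,
    )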
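
Note: the residual fixup can likewise be illustrated outside the diff. A minimal sketch under stated assumptions: the class name ResidualFlowSketch is hypothetical, plain nn.LayerNorm and nn.Identity stand in for the fused norm wrappers and the real attention/MLP blocks. It only shows the new control flow, where the residual is created and consumed inside the layer and forward() returns a single tensor.

import torch
from torch import nn


class ResidualFlowSketch(nn.Module):
    # Hypothetical stand-in for FlashPhi3SmallLayer showing the new residual
    # handling: no residual is threaded between layers any more.
    def __init__(self, hidden_size: int):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size)
        self.self_attn = nn.Identity()  # placeholder for flash attention
        self.mlp = nn.Identity()        # placeholder for the MLP

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        residual = hidden_states
        attn_output = self.self_attn(self.input_layernorm(hidden_states))
        # The fused layer norm in the real code returns (normed, attn_res) from
        # post_attention_layernorm(attn_output, residual); with plain LayerNorm
        # the residual add is written out explicitly.
        attn_res = attn_output + residual
        mlp_output = self.mlp(self.post_attention_layernorm(attn_res))
        return attn_res + mlp_output

With this shape, the model loop reduces to hidden_states = layer(hidden_states, ...) per layer followed by a final hidden_states = self.norm(hidden_states), matching the simplified forward paths in the diff.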