From 5d2b93ba42529ff42516539f367069b8a3800991 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Thu, 13 Jun 2024 10:38:56 +0200 Subject: [PATCH] Fixup residual, initial block attention config --- .../flash_phi3small_modeling.py | 56 ++++++++++++++----- 1 file changed, 43 insertions(+), 13 deletions(-) diff --git a/server/text_generation_server/models/custom_modeling/flash_phi3small_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi3small_modeling.py index 325e8c2d..21827da5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi3small_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi3small_modeling.py @@ -19,6 +19,7 @@ # limitations under the License. from typing import List, Optional, Tuple +from dataclasses import dataclass import torch import torch.distributed @@ -50,6 +51,14 @@ if SYSTEM == "rocm": raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") +@dataclass +class BlockSparseAttentionConfig: + block_size: int + homo_head_pattern: bool + num_local_blocks: int + vert_stride: int + + def load_attention(config, prefix, weights): # Only defined in granite. bias = getattr(config, "attention_bias", False) @@ -69,6 +78,7 @@ class FlashPhi3SmallAttention(torch.nn.Module): self, prefix: str, config, + layer_id: int, weights, ): super().__init__() @@ -83,7 +93,10 @@ class FlashPhi3SmallAttention(torch.nn.Module): device=weights.device, ) - self.softmax_scale = self.head_size**-0.5 + if hasattr(config, "mup_use_scaling") and config.mup_use_scaling: + self.softmax_scale = self.head_size / config.mup_attn_multiplier + else: + self.softmax_scale = self.head_size**-0.5 if self.num_heads % weights.process_group.size() != 0: raise ValueError( @@ -102,11 +115,25 @@ class FlashPhi3SmallAttention(torch.nn.Module): self.query_key_value = load_attention(config, prefix, weights) + is_dense = getattr(config, "dense_attention_every_n_layers", False) and ( + (layer_id + 1) % config.dense_attention_every_n_layers == 0 + ) + + if is_dense: + self.blocksparse_config = None + else: + self.blocksparse_config = BlockSparseAttentionConfig( + block_size=config.blocksparse_block_size, + homo_head_pattern=config.blocksparse_homo_head_pattern, + num_local_blocks=config.blocksparse_num_local_blocks, + vert_stride=config.blocksparse_vert_stride, + ) + self.o_proj = TensorParallelRowLinear.load( config, prefix=f"{prefix}.dense", weights=weights, - bias=False, + bias=True, ) self.num_groups = self.num_heads // self.num_key_value_heads @@ -246,10 +273,13 @@ class Phi3SmallMLP(nn.Module): class FlashPhi3SmallLayer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix, config, layer_id: int, weights): super().__init__() self.self_attn = FlashPhi3SmallAttention( - prefix=f"{prefix}.self_attn", config=config, weights=weights + prefix=f"{prefix}.self_attn", + config=config, + layer_id=layer_id, + weights=weights, ) self.mlp = Phi3SmallMLP(prefix=f"{prefix}.mlp", config=config, weights=weights) @@ -267,7 +297,6 @@ class FlashPhi3SmallLayer(nn.Module): def forward( self, hidden_states, - residual, cos, sin, cu_seqlen_prefill, @@ -277,7 +306,8 @@ class FlashPhi3SmallLayer(nn.Module): input_lengths, max_s, ): - normed_hidden_states, res = self.input_layernorm(hidden_states, residual) + residual = hidden_states + normed_hidden_states, res = self.input_layernorm(hidden_states, None) # Self Attention attn_output = self.self_attn( @@ -294,12 +324,13 @@ class FlashPhi3SmallLayer(nn.Module): # faster post attention rms norm normed_attn_res_output, attn_res = self.post_attention_layernorm( - attn_output, res + attn_output, residual ) mlp_output = self.mlp(normed_attn_res_output) + mlp_output = attn_res + mlp_output - return mlp_output, attn_res + return mlp_output class FlashPhi3SmallModel(torch.nn.Module): @@ -318,12 +349,13 @@ class FlashPhi3SmallModel(torch.nn.Module): else f"{prefix}.model.layers.{layer_id}" ), config=config, + layer_id=layer_id, weights=weights, ) for layer_id in range(config.num_hidden_layers) ] ) - self.norm = FastLayerNorm.load( + self.norm = nn.LayerNorm.load( prefix=( "model.final_layernorm" if not prefix @@ -360,11 +392,9 @@ class FlashPhi3SmallModel(torch.nn.Module): position_ids, max_s, hidden_states.dtype ) - residual = None for i, layer in enumerate(self.layers): - hidden_states, residual = layer( + hidden_states = layer( hidden_states, - residual, cos, sin, cu_seqlen_prefill, @@ -375,7 +405,7 @@ class FlashPhi3SmallModel(torch.nn.Module): max_s, ) - hidden_states, _ = self.norm(hidden_states, residual) + hidden_states = self.norm(hidden_states) return hidden_states