diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 6a6b2e0a..40ccb576 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -18,9 +18,10 @@ # See the License for the specific language governing permissions and # limitations under the License. +from typing import List, Optional, Tuple + import torch import torch.distributed - from torch import nn from transformers.activations import ACT2FN from typing import Optional, List, Tuple diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index 52ea3ae1..fa463a19 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -1,26 +1,21 @@ +from typing import List, Optional, Tuple + import torch import torch.distributed - from torch import nn -from transformers.modeling_utils import PreTrainedModel from transformers.configuration_utils import PretrainedConfig -from typing import Optional, List, Tuple +from transformers.modeling_utils import PreTrainedModel -from text_generation_server.utils import paged_attention, flash_attn -from text_generation_server.utils.flash_attn import attention from text_generation_server.layers import ( - TensorParallelRowLinear, + SpeculativeHead, TensorParallelColumnLinear, TensorParallelEmbedding, - SpeculativeHead, + TensorParallelRowLinear, get_linear, ) -from text_generation_server.layers.layernorm import ( - FastLayerNorm, -) -from text_generation_server.layers.rotary import ( - PositionRotaryEmbedding, -) +from text_generation_server.layers.layernorm import FastLayerNorm +from text_generation_server.layers.rotary import PositionRotaryEmbedding +from text_generation_server.utils import flash_attn, paged_attention def load_row(config, prefix: str, weights, bias: bool): @@ -52,6 +47,7 @@ class RWConfig(PretrainedConfig): hidden_size=64, num_hidden_layers=None, num_attention_heads=None, + num_ln_in_prallel_attention=None, layer_norm_epsilon=1e-5, initializer_range=0.02, use_cache=True, @@ -65,6 +61,7 @@ class RWConfig(PretrainedConfig): new_decoder_architecture=None, bias=False, parallel_attn=False, + rope_theta=10_000.0, **kwargs, ): if alibi: @@ -75,6 +72,7 @@ class RWConfig(PretrainedConfig): self.model_type = model_type self.alibi = False self.rotary = True + self.rope_theta = rope_theta self.vocab_size = vocab_size # Backward compatibility with n_embed kwarg @@ -91,6 +89,7 @@ class RWConfig(PretrainedConfig): else kwargs.pop("n_head", 8) ) self.layer_norm_epsilon = layer_norm_epsilon + self.num_ln_in_parallel_attention = num_ln_in_prallel_attention self.initializer_range = initializer_range self.use_cache = use_cache self.hidden_dropout = hidden_dropout @@ -132,9 +131,13 @@ class FlashRWAttention(torch.nn.Module): self.num_heads_kv = config.n_head_kv self.hidden_size = config.hidden_size self.head_size = self.hidden_size // self.num_heads + self.rope_theta = config.rope_theta self.rotary_emb = PositionRotaryEmbedding.static( - config=config, dim=self.head_size, base=10000.0, device=weights.device + config=config, + dim=self.head_size, + base=self.rope_theta, + device=weights.device, ) self.softmax_scale = self.head_size ** (-0.5) @@ -244,9 +247,13 @@ class FlashRWLargeAttention(torch.nn.Module): self.hidden_size = hidden_size self.head_size = hidden_size // num_heads self.num_groups = num_groups + self.rope_theta = config.rope_theta self.rotary_emb = PositionRotaryEmbedding.static( - config=config, dim=self.head_size, base=10000.0, device=weights.device + config=config, + dim=self.head_size, + base=self.rope_theta, + device=weights.device, ) self.softmax_scale = self.head_size ** (-0.5) @@ -257,7 +264,7 @@ class FlashRWLargeAttention(torch.nn.Module): if process_group.size() > self.num_groups: raise NotImplementedError( - f"Tensor Parallelism is not implemented for world_size > n groups" + "Tensor Parallelism is not implemented for world_size > n groups" ) if self.num_groups % process_group.size() != 0: raise NotImplementedError( @@ -459,29 +466,61 @@ class FlashRWLayer(nn.Module): max_s, ) - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual - ) + if self.post_attention_layernorm is not None: + hidden_states, residual = self.post_attention_layernorm( + hidden_states, residual + ) mlp_output = self.mlp(hidden_states) return mlp_output, residual +class FlashRWLayerNorm(nn.Module): + def __init__(self, config, prefix, weights): + super().__init__() + self.num_ln = config.num_ln_in_parallel_attn + + if self.num_ln == 1: + self.input_ln = FastLayerNorm.load( + prefix=f"{prefix}.input_layernorm", + weights=weights, + eps=config.layer_norm_epsilon, + ) + elif self.num_ln == 2: + self.ln_attn = FastLayerNorm.load( + prefix=f"{prefix}.ln_attn", + weights=weights, + eps=config.layer_norm_epsilon, + ) + self.ln_mlp = FastLayerNorm.load( + prefix=f"{prefix}.ln_mlp", + weights=weights, + eps=config.layer_norm_epsilon, + ) + else: + raise ValueError("Number of layer norms can either be 1 or 2.") + + def forward( + self, + hidden_states, + residual, + ): + if self.num_ln == 1: + ln_hidden_states, residual = self.input_ln(hidden_states, residual) + return ln_hidden_states, ln_hidden_states, residual + elif self.num_ln == 2: + ln_attn, residual = self.ln_attn(hidden_states, residual) + ln_mlp, _ = self.ln_mlp(residual) + return ln_attn, ln_mlp, residual + + class FlashRWLargeLayer(nn.Module): def __init__(self, layer_id, config, weights): super().__init__() prefix = f"transformer.h.{layer_id}" - self.ln_attn = FastLayerNorm.load( - prefix=f"{prefix}.ln_attn", - weights=weights, - eps=config.layer_norm_epsilon, - ) - self.ln_mlp = FastLayerNorm.load( - prefix=f"{prefix}.ln_mlp", - weights=weights, - eps=config.layer_norm_epsilon, - ) + + self.ln_layer = FlashRWLayerNorm(config, prefix, weights) self.self_attention = FlashRWLargeAttention( config, @@ -507,8 +546,8 @@ class FlashRWLargeLayer(nn.Module): input_lengths, max_s, ): - ln_attn, residual = self.ln_attn(hidden_states, residual) - ln_mlp, _ = self.ln_mlp(residual) + # Layer norm. + ln_attn, ln_mlp, residual = self.ln_layer(hidden_states, residual) # Self attention. attn_output = self.self_attention(