From 0b7df771783325d1c5c502cc0a29aac3f5487b26 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 26 Sep 2024 10:54:08 +0200
Subject: [PATCH] Add LoRA adapters support for Gemma2 (#2567)

* Add LoRA adapters support for Gemma2

* Make `black` formatting happy
---
 .../custom_modeling/flash_gemma2_modeling.py  | 84 +++++++++++++++----
 1 file changed, 70 insertions(+), 14 deletions(-)

diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
index e12bff00..887e187e 100644
--- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py
@@ -38,6 +38,8 @@ from text_generation_server.layers import (
     TensorParallelEmbedding,
     SpeculativeHead,
     get_linear,
+    TensorParallelMultiAdapterLinear,
+    TensorParallelAdapterRowLinear,
 )
 from text_generation_server.layers.rotary import PositionRotaryEmbedding
 from text_generation_server.layers.layernorm import (
@@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
 
 
 class FlashGemma2Attention(torch.nn.Module):
-    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.num_heads = config.num_attention_heads
         self.head_size = config.head_dim
@@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
         )
         self.softcap = config.attn_logit_softcapping
 
-        self.query_key_value = load_attention(config, prefix, weights)
+        query_key_value = load_attention(config, prefix, weights)
+        self.query_key_value = TensorParallelMultiAdapterLinear.load(
+            query_key_value,
+            layer_id,
+            ["q_proj", "k_proj", "v_proj"],
+            sizes=[
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
+            ],
+            process_group=weights.process_group,
+        )
 
-        self.o_proj = TensorParallelRowLinear.load(
+        o_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.o_proj",
             weights=weights,
             bias=False,
         )
+        self.o_proj = TensorParallelAdapterRowLinear.load(
+            o_proj,
+            layer_id,
+            "o_proj",
+            process_group=weights.process_group,
+        )
+
         self.num_groups = self.num_heads // self.num_key_value_heads
         self.kv_head_mapping = torch.arange(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
         slots,
         seqlen,
         max_s,
+        adapter_data,
     ):
-        qkv = self.query_key_value(hidden_states)
+        qkv = self.query_key_value(hidden_states, adapter_data)
         query, kv = qkv.split(
             [
                 self.head_size * self.num_heads,
@@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
                 softcap=self.softcap,
             )
 
-        return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
+        return self.o_proj(
+            attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+        )
 
 
 class Gemma2MLP(nn.Module):
-    def __init__(self, prefix, config, weights):
+    def __init__(self, prefix, config, weights, layer_id):
         super().__init__()
         act = config.hidden_activation
         self.act = (
@@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
             )
         )
         # Fuse gate and up proj
-        self.gate_up_proj = TensorParallelColumnLinear.load_multi(
+        gate_up_proj = TensorParallelColumnLinear.load_multi(
             config,
             prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
             weights=weights,
             dim=0,
             bias=False,
         )
-        self.down_proj = TensorParallelRowLinear.load(
+        self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
+            gate_up_proj,
+            layer_id,
+            ["gate_proj", "up_proj"],
+            sizes=[
+                config.intermediate_size,
+                config.intermediate_size,
+            ],
+            process_group=weights.process_group,
+        )
+
+        down_proj = TensorParallelRowLinear.load(
             config,
             prefix=f"{prefix}.down_proj",
             weights=weights,
             bias=False,
         )
+        self.down_proj = TensorParallelAdapterRowLinear.load(
+            down_proj,
+            layer_id,
+            "down_proj",
+            process_group=weights.process_group,
+        )
+
         self.intermediate_size = (
             config.intermediate_size // weights.process_group.size()
         )
 
-    def forward(self, hidden_states):
-        gate_up_states = self.gate_up_proj(hidden_states)
+    def forward(self, hidden_states, adapter_data):
+        gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
         gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
-        return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
+        return self.down_proj(
+            self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
+        )
 
 
 class FlashGemma2Layer(nn.Module):
-    def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
+    def __init__(
+        self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
+    ):
         super().__init__()
         self.self_attn = FlashGemma2Attention(
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
+            layer_id=layer_id,
             causal=causal,
             is_sliding=is_sliding,
         )
-        self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
+        self.mlp = Gemma2MLP(
+            prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
+        )
 
         self.input_layernorm = Gemma2FastRMSNorm.load(
             prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
         slots,
         seqlen,
         max_s,
+        adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
 
@@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
             slots,
             seqlen,
             max_s,
+            adapter_data,
         )
 
         # faster post attention rms norm
@@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
         res = normed_attn_res_output
 
         pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
-        mlp_output = self.mlp(pre_normed)
+        mlp_output = self.mlp(pre_normed, adapter_data)
         post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
 
         return post_hidden_states, normed_attn_res_output
@@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
                     prefix=f"{prefix}.layers.{layer_id}",
                     config=config,
                     weights=weights,
+                    layer_id=layer_id,
                     causal=causal,
                     is_sliding=layer_id % 2 == 0,
                 )
@@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
         slots: torch.Tensor,
         seqlen: Seqlen,
         max_s: int,
+        adapter_data: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
 
@@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
                 slots,
                 seqlen,
                 max_s,
+                adapter_data,
             )
 
         hidden_states, _ = self.norm(hidden_states, residual)
@@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
            slots,
             seqlen,
             max_s,
+            adapter_data,
         )
         if lm_head_indices is not None:
             hidden_states = hidden_states[lm_head_indices]
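
For readers wiring the same pattern into another model, the sketch below condenses the per-projection recipe this patch applies, using the TGI layer helpers imported at the top of the diff. The function name wrap_attention_projections and the standalone query_key_value/config/weights/prefix/layer_id arguments are illustrative stand-ins, not part of the patch; inside the modeling file the fused QKV linear comes from the existing load_attention helper, and the MLP projections follow the same wrap-then-thread-adapter_data pattern with ["gate_proj", "up_proj"] and "down_proj".

# Illustrative sketch (not part of the patch): the adapter-wrapping recipe the
# diff applies to each attention projection.  query_key_value, config, weights,
# prefix and layer_id are stand-ins for the objects TGI passes around.
from text_generation_server.layers import (
    TensorParallelRowLinear,
    TensorParallelMultiAdapterLinear,
    TensorParallelAdapterRowLinear,
)


def wrap_attention_projections(query_key_value, config, weights, prefix, layer_id):
    # 1) Wrap the fused q/k/v column-parallel linear so per-request LoRA deltas
    #    can be applied to each sub-projection of this layer.
    query_key_value = TensorParallelMultiAdapterLinear.load(
        query_key_value,
        layer_id,
        ["q_proj", "k_proj", "v_proj"],
        sizes=[
            config.head_dim * config.num_attention_heads,
            config.head_dim * config.num_key_value_heads,
            config.head_dim * config.num_key_value_heads,
        ],
        process_group=weights.process_group,
    )

    # 2) Same idea for the row-parallel output projection.
    o_proj = TensorParallelRowLinear.load(
        config, prefix=f"{prefix}.o_proj", weights=weights, bias=False
    )
    o_proj = TensorParallelAdapterRowLinear.load(
        o_proj, layer_id, "o_proj", process_group=weights.process_group
    )
    return query_key_value, o_proj


# 3) At run time every wrapped linear takes the extra adapter_data argument,
#    threaded down from FlashGemma2ForCausalLM.forward as in the diff:
#        qkv = self.query_key_value(hidden_states, adapter_data)
#        out = self.o_proj(attn_output.view(-1, heads * head_size), adapter_data)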