Add LoRA adapters support for Gemma2 (#2567)

* Add LoRA adapters support for Gemma2

* Make `black` formatting happy
This commit is contained in:
Alvaro Bartolome 2024-09-26 10:54:08 +02:00 committed by GitHub
parent 7efcb5e0ed
commit 0b7df77178
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 70 additions and 14 deletions

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelEmbedding,
SpeculativeHead,
get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
)
from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import (
@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
class FlashGemma2Attention(torch.nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.num_heads = config.num_attention_heads
self.head_size = config.head_dim
@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
)
self.softcap = config.attn_logit_softcapping
self.query_key_value = load_attention(config, prefix, weights)
query_key_value = load_attention(config, prefix, weights)
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
self.head_size * config.num_attention_heads,
self.head_size * config.num_key_value_heads,
self.head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
self.o_proj = TensorParallelRowLinear.load(
o_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.o_proj",
weights=weights,
bias=False,
)
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
):
qkv = self.query_key_value(hidden_states)
qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split(
[
self.head_size * self.num_heads,
@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
softcap=self.softcap,
)
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class Gemma2MLP(nn.Module):
def __init__(self, prefix, config, weights):
def __init__(self, prefix, config, weights, layer_id):
super().__init__()
act = config.hidden_activation
self.act = (
@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
)
)
# Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
gate_up_proj = TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights,
dim=0,
bias=False,
)
self.down_proj = TensorParallelRowLinear.load(
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config,
prefix=f"{prefix}.down_proj",
weights=weights,
bias=False,
)
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = (
config.intermediate_size // weights.process_group.size()
)
def forward(self, hidden_states):
gate_up_states = self.gate_up_proj(hidden_states)
def forward(self, hidden_states, adapter_data):
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class FlashGemma2Layer(nn.Module):
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
def __init__(
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
):
super().__init__()
self.self_attn = FlashGemma2Attention(
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=is_sliding,
)
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.mlp = Gemma2MLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
)
self.input_layernorm = Gemma2FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
slots,
seqlen,
max_s,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
# faster post attention rms norm
@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
res = normed_attn_res_output
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
mlp_output = self.mlp(pre_normed)
mlp_output = self.mlp(pre_normed, adapter_data)
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
return post_hidden_states, normed_attn_res_output
@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}",
config=config,
weights=weights,
layer_id=layer_id,
causal=causal,
is_sliding=layer_id % 2 == 0,
)
@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
slots: torch.Tensor,
seqlen: Seqlen,
max_s: int,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor:
hidden_states = inputs_embeds
@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
hidden_states, _ = self.norm(hidden_states, residual)
@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
slots,
seqlen,
max_s,
adapter_data,
)
if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices]