Add LoRA adapters support for Gemma2 (#2567)
* Add LoRA adapters support for Gemma2 * Make `black` formatting happy
This commit is contained in:
parent
7efcb5e0ed
commit
0b7df77178
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
||||||
TensorParallelEmbedding,
|
TensorParallelEmbedding,
|
||||||
SpeculativeHead,
|
SpeculativeHead,
|
||||||
get_linear,
|
get_linear,
|
||||||
|
TensorParallelMultiAdapterLinear,
|
||||||
|
TensorParallelAdapterRowLinear,
|
||||||
)
|
)
|
||||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||||
from text_generation_server.layers.layernorm import (
|
from text_generation_server.layers.layernorm import (
|
||||||
|
@ -161,7 +163,9 @@ def _load_gqa(config, prefix: str, weights):
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma2Attention(torch.nn.Module):
|
class FlashGemma2Attention(torch.nn.Module):
|
||||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.num_heads = config.num_attention_heads
|
self.num_heads = config.num_attention_heads
|
||||||
self.head_size = config.head_dim
|
self.head_size = config.head_dim
|
||||||
|
@ -192,14 +196,32 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
)
|
)
|
||||||
self.softcap = config.attn_logit_softcapping
|
self.softcap = config.attn_logit_softcapping
|
||||||
|
|
||||||
self.query_key_value = load_attention(config, prefix, weights)
|
query_key_value = load_attention(config, prefix, weights)
|
||||||
|
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||||
|
query_key_value,
|
||||||
|
layer_id,
|
||||||
|
["q_proj", "k_proj", "v_proj"],
|
||||||
|
sizes=[
|
||||||
|
self.head_size * config.num_attention_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
self.head_size * config.num_key_value_heads,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.o_proj = TensorParallelRowLinear.load(
|
o_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.o_proj",
|
prefix=f"{prefix}.o_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
o_proj,
|
||||||
|
layer_id,
|
||||||
|
"o_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||||
self.kv_head_mapping = torch.arange(
|
self.kv_head_mapping = torch.arange(
|
||||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||||
|
@ -216,8 +238,9 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
qkv = self.query_key_value(hidden_states)
|
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||||
query, kv = qkv.split(
|
query, kv = qkv.split(
|
||||||
[
|
[
|
||||||
self.head_size * self.num_heads,
|
self.head_size * self.num_heads,
|
||||||
|
@ -260,11 +283,13 @@ class FlashGemma2Attention(torch.nn.Module):
|
||||||
softcap=self.softcap,
|
softcap=self.softcap,
|
||||||
)
|
)
|
||||||
|
|
||||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
return self.o_proj(
|
||||||
|
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class Gemma2MLP(nn.Module):
|
class Gemma2MLP(nn.Module):
|
||||||
def __init__(self, prefix, config, weights):
|
def __init__(self, prefix, config, weights, layer_id):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
act = config.hidden_activation
|
act = config.hidden_activation
|
||||||
self.act = (
|
self.act = (
|
||||||
|
@ -278,40 +303,65 @@ class Gemma2MLP(nn.Module):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
# Fuse gate and up proj
|
# Fuse gate and up proj
|
||||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||||
config,
|
config,
|
||||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||||
weights=weights,
|
weights=weights,
|
||||||
dim=0,
|
dim=0,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
self.down_proj = TensorParallelRowLinear.load(
|
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||||
|
gate_up_proj,
|
||||||
|
layer_id,
|
||||||
|
["gate_proj", "up_proj"],
|
||||||
|
sizes=[
|
||||||
|
config.intermediate_size,
|
||||||
|
config.intermediate_size,
|
||||||
|
],
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
|
down_proj = TensorParallelRowLinear.load(
|
||||||
config,
|
config,
|
||||||
prefix=f"{prefix}.down_proj",
|
prefix=f"{prefix}.down_proj",
|
||||||
weights=weights,
|
weights=weights,
|
||||||
bias=False,
|
bias=False,
|
||||||
)
|
)
|
||||||
|
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||||
|
down_proj,
|
||||||
|
layer_id,
|
||||||
|
"down_proj",
|
||||||
|
process_group=weights.process_group,
|
||||||
|
)
|
||||||
|
|
||||||
self.intermediate_size = (
|
self.intermediate_size = (
|
||||||
config.intermediate_size // weights.process_group.size()
|
config.intermediate_size // weights.process_group.size()
|
||||||
)
|
)
|
||||||
|
|
||||||
def forward(self, hidden_states):
|
def forward(self, hidden_states, adapter_data):
|
||||||
gate_up_states = self.gate_up_proj(hidden_states)
|
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
return self.down_proj(
|
||||||
|
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class FlashGemma2Layer(nn.Module):
|
class FlashGemma2Layer(nn.Module):
|
||||||
def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool):
|
def __init__(
|
||||||
|
self, prefix: str, config, weights, layer_id, causal: bool, is_sliding: bool
|
||||||
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.self_attn = FlashGemma2Attention(
|
self.self_attn = FlashGemma2Attention(
|
||||||
prefix=f"{prefix}.self_attn",
|
prefix=f"{prefix}.self_attn",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
is_sliding=is_sliding,
|
is_sliding=is_sliding,
|
||||||
)
|
)
|
||||||
self.mlp = Gemma2MLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
self.mlp = Gemma2MLP(
|
||||||
|
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||||
|
)
|
||||||
|
|
||||||
self.input_layernorm = Gemma2FastRMSNorm.load(
|
self.input_layernorm = Gemma2FastRMSNorm.load(
|
||||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||||
|
@ -344,6 +394,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
):
|
):
|
||||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||||
|
|
||||||
|
@ -358,6 +409,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
# faster post attention rms norm
|
# faster post attention rms norm
|
||||||
|
@ -366,7 +418,7 @@ class FlashGemma2Layer(nn.Module):
|
||||||
res = normed_attn_res_output
|
res = normed_attn_res_output
|
||||||
|
|
||||||
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
pre_normed, _ = self.pre_feedforward_layernorm(normed_attn_res_output)
|
||||||
mlp_output = self.mlp(pre_normed)
|
mlp_output = self.mlp(pre_normed, adapter_data)
|
||||||
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
post_hidden_states, _ = self.post_feedforward_layernorm(mlp_output)
|
||||||
|
|
||||||
return post_hidden_states, normed_attn_res_output
|
return post_hidden_states, normed_attn_res_output
|
||||||
|
@ -385,6 +437,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
prefix=f"{prefix}.layers.{layer_id}",
|
prefix=f"{prefix}.layers.{layer_id}",
|
||||||
config=config,
|
config=config,
|
||||||
weights=weights,
|
weights=weights,
|
||||||
|
layer_id=layer_id,
|
||||||
causal=causal,
|
causal=causal,
|
||||||
is_sliding=layer_id % 2 == 0,
|
is_sliding=layer_id % 2 == 0,
|
||||||
)
|
)
|
||||||
|
@ -409,6 +462,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
slots: torch.Tensor,
|
slots: torch.Tensor,
|
||||||
seqlen: Seqlen,
|
seqlen: Seqlen,
|
||||||
max_s: int,
|
max_s: int,
|
||||||
|
adapter_data: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
hidden_states = inputs_embeds
|
hidden_states = inputs_embeds
|
||||||
|
|
||||||
|
@ -431,6 +485,7 @@ class FlashGemma2Model(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
|
|
||||||
hidden_states, _ = self.norm(hidden_states, residual)
|
hidden_states, _ = self.norm(hidden_states, residual)
|
||||||
|
@ -492,6 +547,7 @@ class FlashGemma2ForCausalLM(torch.nn.Module):
|
||||||
slots,
|
slots,
|
||||||
seqlen,
|
seqlen,
|
||||||
max_s,
|
max_s,
|
||||||
|
adapter_data,
|
||||||
)
|
)
|
||||||
if lm_head_indices is not None:
|
if lm_head_indices is not None:
|
||||||
hidden_states = hidden_states[lm_head_indices]
|
hidden_states = hidden_states[lm_head_indices]
|
||||||
|
|
Loading…
Reference in New Issue