feat: add lora support to mistral and refactors
This commit is contained in:
parent
9c45d34983
commit
de56a81c5c
|
@ -168,9 +168,12 @@ def download_weights(
|
|||
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
|
||||
pass
|
||||
else:
|
||||
utils.peft.download_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
try:
|
||||
utils.peft.download_peft(
|
||||
model_id, revision, trust_remote_code=trust_remote_code
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
import json
|
||||
|
|
|
@ -38,6 +38,8 @@ from text_generation_server.layers import (
|
|||
TensorParallelEmbedding,
|
||||
SpeculativeHead,
|
||||
get_linear,
|
||||
TensorParallelMultiAdapterLinear,
|
||||
TensorParallelAdapterRowLinear,
|
||||
)
|
||||
from text_generation_server.layers.rotary import PositionRotaryEmbedding
|
||||
from text_generation_server.layers.layernorm import (
|
||||
|
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
|
|||
|
||||
|
||||
class MistralAttention(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
config,
|
||||
weights,
|
||||
):
|
||||
def __init__(self, prefix: str, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.max_past = (
|
||||
config.sliding_window if config.sliding_window is not None else -1
|
||||
|
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
|
|||
config.num_key_value_heads // weights.process_group.size()
|
||||
)
|
||||
|
||||
self.query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
query_key_value = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
|
||||
dim=0,
|
||||
|
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
|
|||
bias=False,
|
||||
)
|
||||
|
||||
self.o_proj = TensorParallelRowLinear.load(
|
||||
head_size = config.hidden_size // config.num_attention_heads
|
||||
self.query_key_value = TensorParallelMultiAdapterLinear.load(
|
||||
query_key_value,
|
||||
layer_id,
|
||||
["q_proj", "k_proj", "v_proj"],
|
||||
sizes=[
|
||||
head_size * config.num_attention_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
head_size * config.num_key_value_heads,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
o_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.o_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
self.o_proj = TensorParallelAdapterRowLinear.load(
|
||||
o_proj,
|
||||
layer_id,
|
||||
"o_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.num_groups = self.num_heads // self.num_key_value_heads
|
||||
self.kv_head_mapping = torch.arange(
|
||||
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
|
||||
|
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
qkv = self.query_key_value(hidden_states)
|
||||
qkv = self.query_key_value(hidden_states, adapter_data)
|
||||
query, kv = qkv.split(
|
||||
[
|
||||
self.head_size * self.num_heads,
|
||||
|
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
|
|||
max_s,
|
||||
)
|
||||
|
||||
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size))
|
||||
return self.o_proj(
|
||||
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
|
||||
)
|
||||
|
||||
|
||||
class MistralMLP(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.hidden_act = config.hidden_act
|
||||
self.act = (
|
||||
|
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
|
|||
)
|
||||
)
|
||||
# Fuse gate and up proj
|
||||
self.gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
gate_up_proj = TensorParallelColumnLinear.load_multi(
|
||||
config,
|
||||
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
|
||||
weights=weights,
|
||||
dim=0,
|
||||
bias=False,
|
||||
)
|
||||
self.down_proj = TensorParallelRowLinear.load(
|
||||
self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
|
||||
gate_up_proj,
|
||||
layer_id,
|
||||
["gate_proj", "up_proj"],
|
||||
sizes=[
|
||||
config.intermediate_size,
|
||||
config.intermediate_size,
|
||||
],
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
|
||||
down_proj = TensorParallelRowLinear.load(
|
||||
config,
|
||||
prefix=f"{prefix}.down_proj",
|
||||
weights=weights,
|
||||
bias=False,
|
||||
)
|
||||
|
||||
self.down_proj = TensorParallelAdapterRowLinear.load(
|
||||
down_proj,
|
||||
layer_id,
|
||||
"down_proj",
|
||||
process_group=weights.process_group,
|
||||
)
|
||||
self.intermediate_size = (
|
||||
config.intermediate_size // weights.process_group.size()
|
||||
)
|
||||
|
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
|
|||
# TODO: This is a hotfix to be removed & properly refactored.
|
||||
self.quantize = config.quantize
|
||||
|
||||
def forward(self, hidden_states):
|
||||
def forward(self, hidden_states, adapter_data):
|
||||
if (
|
||||
SYSTEM == "rocm"
|
||||
and self.hidden_act == "silu"
|
||||
|
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
|
|||
device="cuda",
|
||||
)
|
||||
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
|
||||
return self.down_proj(out)
|
||||
return self.down_proj(out, adapter_data)
|
||||
else:
|
||||
gate_up_states = self.gate_up_proj(hidden_states)
|
||||
gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
|
||||
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
|
||||
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1])
|
||||
return self.down_proj(
|
||||
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
|
||||
)
|
||||
|
||||
|
||||
class MistralLayer(nn.Module):
|
||||
def __init__(self, prefix, config, weights):
|
||||
def __init__(self, prefix, config, weights, layer_id):
|
||||
super().__init__()
|
||||
self.self_attn = MistralAttention(
|
||||
prefix=f"{prefix}.self_attn", config=config, weights=weights
|
||||
prefix=f"{prefix}.self_attn",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
self.mlp = MistralMLP(
|
||||
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
|
||||
)
|
||||
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
|
||||
|
||||
self.input_layernorm = FastRMSNorm.load(
|
||||
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
|
||||
|
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
):
|
||||
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
|
||||
|
||||
|
@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
# faster post attention rms norm
|
||||
|
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
|
|||
attn_output, res
|
||||
)
|
||||
|
||||
mlp_output = self.mlp(normed_attn_res_output)
|
||||
mlp_output = self.mlp(normed_attn_res_output, adapter_data)
|
||||
|
||||
return mlp_output, attn_res
|
||||
|
||||
|
@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
|
|||
prefix=f"{prefix}.layers.{layer_id}",
|
||||
config=config,
|
||||
weights=weights,
|
||||
layer_id=layer_id,
|
||||
)
|
||||
for layer_id in range(config.num_hidden_layers)
|
||||
]
|
||||
|
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
|
|||
max_s: int,
|
||||
true_max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
):
|
||||
hidden_states = inputs_embeds
|
||||
# Get rotary cos and sin for this forward
|
||||
|
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
|
|||
input_lengths,
|
||||
max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
|
||||
hidden_states, _ = self.norm(hidden_states, residual)
|
||||
|
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
max_s: int,
|
||||
prefill_cache_indices: Optional[torch.Tensor],
|
||||
lm_head_indices: Optional[torch.Tensor] = None,
|
||||
adapter_data: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
true_max_s = max_s
|
||||
if prefill_cache_indices is not None:
|
||||
|
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
|
|||
max_s,
|
||||
true_max_s,
|
||||
prefill_cache_indices,
|
||||
adapter_data,
|
||||
)
|
||||
if lm_head_indices is not None:
|
||||
hidden_states = hidden_states[lm_head_indices]
|
||||
|
|
|
@ -3,7 +3,7 @@ import torch.distributed
|
|||
|
||||
from opentelemetry import trace
|
||||
from transformers import AutoTokenizer, AutoConfig
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional, Tuple, Dict, List
|
||||
|
||||
from text_generation_server.models import FlashCausalLM
|
||||
from text_generation_server.models.flash_causal_lm import set_sliding_window
|
||||
|
@ -21,6 +21,31 @@ from text_generation_server.utils.import_utils import SYSTEM
|
|||
tracer = trace.get_tracer(__name__)
|
||||
|
||||
|
||||
Q_PROJ = "q_proj"
|
||||
K_PROJ = "k_proj"
|
||||
V_PROJ = "v_proj"
|
||||
O_PROJ = "o_proj"
|
||||
|
||||
GATE_PROJ = "gate_proj"
|
||||
UP_PROJ = "up_proj"
|
||||
DOWN_PROJ = "down_proj"
|
||||
|
||||
LM_HEAD = "lm_head"
|
||||
|
||||
|
||||
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
|
||||
ADAPTER_LAYERS = [
|
||||
Q_PROJ,
|
||||
K_PROJ,
|
||||
V_PROJ,
|
||||
O_PROJ,
|
||||
GATE_PROJ,
|
||||
UP_PROJ,
|
||||
DOWN_PROJ,
|
||||
] # LM_HEAD
|
||||
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
|
||||
|
||||
|
||||
class BaseFlashMistral(FlashCausalLM):
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -99,6 +124,62 @@ class BaseFlashMistral(FlashCausalLM):
|
|||
model.model.head_size,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_adapter_loading(self) -> bool:
|
||||
return True
|
||||
|
||||
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
|
||||
layer_weights = {}
|
||||
|
||||
prefix = "model.layers"
|
||||
for i, layer in enumerate(self.model.model.layers):
|
||||
layer_weights[(i, Q_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.q_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, K_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.k_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, V_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.v_proj",
|
||||
layer.self_attn.query_key_value,
|
||||
)
|
||||
layer_weights[(i, O_PROJ)] = (
|
||||
f"{prefix}.{i}.self_attn.o_proj",
|
||||
layer.self_attn.o_proj,
|
||||
)
|
||||
|
||||
layer_weights[(i, GATE_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.gate_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, UP_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.up_proj",
|
||||
layer.mlp.gate_up_proj,
|
||||
)
|
||||
layer_weights[(i, DOWN_PROJ)] = (
|
||||
f"{prefix}.{i}.mlp.down_proj",
|
||||
layer.mlp.down_proj,
|
||||
)
|
||||
|
||||
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
|
||||
return layer_weights
|
||||
|
||||
@property
|
||||
def adapter_layers(self) -> List[str]:
|
||||
return ADAPTER_LAYERS
|
||||
|
||||
@property
|
||||
def default_traced_adapter_layers(self) -> List[str]:
|
||||
return [Q_PROJ, V_PROJ]
|
||||
|
||||
def get_num_layers_for_type(self, layer_type: str) -> int:
|
||||
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
|
||||
|
||||
def is_row_parallel(self, layer_type: str) -> bool:
|
||||
return layer_type in ROW_PARALLEL
|
||||
|
||||
|
||||
class FlashMistral(BaseFlashMistral):
|
||||
def __init__(
|
||||
|
@ -111,7 +192,6 @@ class FlashMistral(BaseFlashMistral):
|
|||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(FlashMistral, self).__init__(
|
||||
model_id=model_id,
|
||||
config_cls=MistralConfig,
|
||||
model_cls=FlashMistralForCausalLM,
|
||||
model_id=model_id,
|
||||
|
|
|
@ -20,7 +20,6 @@ class FlashMixtral(BaseFlashMistral):
|
|||
trust_remote_code: bool = False,
|
||||
):
|
||||
super(FlashMixtral, self).__init__(
|
||||
model_id=model_id,
|
||||
config_cls=MixtralConfig,
|
||||
model_cls=FlashMixtralForCausalLM,
|
||||
model_id=model_id,
|
||||
|
|
Loading…
Reference in New Issue