feat: add lora support to mistral and refactors

This commit is contained in:
drbh 2024-06-06 22:44:58 +00:00
parent 9c45d34983
commit de56a81c5c
4 changed files with 160 additions and 27 deletions

View File

@ -168,9 +168,12 @@ def download_weights(
except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError): except (utils.LocalEntryNotFoundError, utils.EntryNotFoundError):
pass pass
else: else:
try:
utils.peft.download_peft( utils.peft.download_peft(
model_id, revision, trust_remote_code=trust_remote_code model_id, revision, trust_remote_code=trust_remote_code
) )
except Exception:
pass
try: try:
import json import json

View File

@ -38,6 +38,8 @@ from text_generation_server.layers import (
TensorParallelEmbedding, TensorParallelEmbedding,
SpeculativeHead, SpeculativeHead,
get_linear, get_linear,
TensorParallelMultiAdapterLinear,
TensorParallelAdapterRowLinear,
) )
from text_generation_server.layers.rotary import PositionRotaryEmbedding from text_generation_server.layers.rotary import PositionRotaryEmbedding
from text_generation_server.layers.layernorm import ( from text_generation_server.layers.layernorm import (
@ -107,12 +109,7 @@ class MistralConfig(PretrainedConfig):
class MistralAttention(torch.nn.Module): class MistralAttention(torch.nn.Module):
def __init__( def __init__(self, prefix: str, config, weights, layer_id):
self,
prefix: str,
config,
weights,
):
super().__init__() super().__init__()
self.max_past = ( self.max_past = (
config.sliding_window if config.sliding_window is not None else -1 config.sliding_window if config.sliding_window is not None else -1
@ -140,7 +137,7 @@ class MistralAttention(torch.nn.Module):
config.num_key_value_heads // weights.process_group.size() config.num_key_value_heads // weights.process_group.size()
) )
self.query_key_value = TensorParallelColumnLinear.load_multi( query_key_value = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"], prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0, dim=0,
@ -148,12 +145,31 @@ class MistralAttention(torch.nn.Module):
bias=False, bias=False,
) )
self.o_proj = TensorParallelRowLinear.load( head_size = config.hidden_size // config.num_attention_heads
self.query_key_value = TensorParallelMultiAdapterLinear.load(
query_key_value,
layer_id,
["q_proj", "k_proj", "v_proj"],
sizes=[
head_size * config.num_attention_heads,
head_size * config.num_key_value_heads,
head_size * config.num_key_value_heads,
],
process_group=weights.process_group,
)
o_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.o_proj", prefix=f"{prefix}.o_proj",
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.o_proj = TensorParallelAdapterRowLinear.load(
o_proj,
layer_id,
"o_proj",
process_group=weights.process_group,
)
self.num_groups = self.num_heads // self.num_key_value_heads self.num_groups = self.num_heads // self.num_key_value_heads
self.kv_head_mapping = torch.arange( self.kv_head_mapping = torch.arange(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device 0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
@ -171,8 +187,9 @@ class MistralAttention(torch.nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
): ):
qkv = self.query_key_value(hidden_states) qkv = self.query_key_value(hidden_states, adapter_data)
query, kv = qkv.split( query, kv = qkv.split(
[ [
self.head_size * self.num_heads, self.head_size * self.num_heads,
@ -224,11 +241,13 @@ class MistralAttention(torch.nn.Module):
max_s, max_s,
) )
return self.o_proj(attn_output.view(-1, self.num_heads * self.head_size)) return self.o_proj(
attn_output.view(-1, self.num_heads * self.head_size), adapter_data
)
class MistralMLP(nn.Module): class MistralMLP(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, layer_id):
super().__init__() super().__init__()
self.hidden_act = config.hidden_act self.hidden_act = config.hidden_act
self.act = ( self.act = (
@ -244,19 +263,37 @@ class MistralMLP(nn.Module):
) )
) )
# Fuse gate and up proj # Fuse gate and up proj
self.gate_up_proj = TensorParallelColumnLinear.load_multi( gate_up_proj = TensorParallelColumnLinear.load_multi(
config, config,
prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"], prefixes=[f"{prefix}.gate_proj", f"{prefix}.up_proj"],
weights=weights, weights=weights,
dim=0, dim=0,
bias=False, bias=False,
) )
self.down_proj = TensorParallelRowLinear.load( self.gate_up_proj = TensorParallelMultiAdapterLinear.load(
gate_up_proj,
layer_id,
["gate_proj", "up_proj"],
sizes=[
config.intermediate_size,
config.intermediate_size,
],
process_group=weights.process_group,
)
down_proj = TensorParallelRowLinear.load(
config, config,
prefix=f"{prefix}.down_proj", prefix=f"{prefix}.down_proj",
weights=weights, weights=weights,
bias=False, bias=False,
) )
self.down_proj = TensorParallelAdapterRowLinear.load(
down_proj,
layer_id,
"down_proj",
process_group=weights.process_group,
)
self.intermediate_size = ( self.intermediate_size = (
config.intermediate_size // weights.process_group.size() config.intermediate_size // weights.process_group.size()
) )
@ -264,7 +301,7 @@ class MistralMLP(nn.Module):
# TODO: This is a hotfix to be removed & properly refactored. # TODO: This is a hotfix to be removed & properly refactored.
self.quantize = config.quantize self.quantize = config.quantize
def forward(self, hidden_states): def forward(self, hidden_states, adapter_data):
if ( if (
SYSTEM == "rocm" SYSTEM == "rocm"
and self.hidden_act == "silu" and self.hidden_act == "silu"
@ -278,20 +315,27 @@ class MistralMLP(nn.Module):
device="cuda", device="cuda",
) )
_custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8) _custom_C.LLMM_Silu(self.gate_up_proj.linear.weight, hidden_states, out, 8)
return self.down_proj(out) return self.down_proj(out, adapter_data)
else: else:
gate_up_states = self.gate_up_proj(hidden_states) gate_up_states = self.gate_up_proj(hidden_states, adapter_data)
gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size) gate_up_states = gate_up_states.view(-1, 2, self.intermediate_size)
return self.down_proj(self.act(gate_up_states[:, 0]) * gate_up_states[:, 1]) return self.down_proj(
self.act(gate_up_states[:, 0]) * gate_up_states[:, 1], adapter_data
)
class MistralLayer(nn.Module): class MistralLayer(nn.Module):
def __init__(self, prefix, config, weights): def __init__(self, prefix, config, weights, layer_id):
super().__init__() super().__init__()
self.self_attn = MistralAttention( self.self_attn = MistralAttention(
prefix=f"{prefix}.self_attn", config=config, weights=weights prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
layer_id=layer_id,
)
self.mlp = MistralMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, layer_id=layer_id
) )
self.mlp = MistralMLP(prefix=f"{prefix}.mlp", config=config, weights=weights)
self.input_layernorm = FastRMSNorm.load( self.input_layernorm = FastRMSNorm.load(
prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps prefix=f"{prefix}.input_layernorm", weights=weights, eps=config.rms_norm_eps
@ -315,6 +359,7 @@ class MistralLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
): ):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual) normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@ -330,6 +375,7 @@ class MistralLayer(nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
# faster post attention rms norm # faster post attention rms norm
@ -337,7 +383,7 @@ class MistralLayer(nn.Module):
attn_output, res attn_output, res
) )
mlp_output = self.mlp(normed_attn_res_output) mlp_output = self.mlp(normed_attn_res_output, adapter_data)
return mlp_output, attn_res return mlp_output, attn_res
@ -355,6 +401,7 @@ class MistralModel(torch.nn.Module):
prefix=f"{prefix}.layers.{layer_id}", prefix=f"{prefix}.layers.{layer_id}",
config=config, config=config,
weights=weights, weights=weights,
layer_id=layer_id,
) )
for layer_id in range(config.num_hidden_layers) for layer_id in range(config.num_hidden_layers)
] ]
@ -381,6 +428,7 @@ class MistralModel(torch.nn.Module):
max_s: int, max_s: int,
true_max_s: int, true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
adapter_data: Optional[torch.Tensor] = None,
): ):
hidden_states = inputs_embeds hidden_states = inputs_embeds
# Get rotary cos and sin for this forward # Get rotary cos and sin for this forward
@ -403,6 +451,7 @@ class MistralModel(torch.nn.Module):
input_lengths, input_lengths,
max_s, max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
hidden_states, _ = self.norm(hidden_states, residual) hidden_states, _ = self.norm(hidden_states, residual)
@ -454,6 +503,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s: int, max_s: int,
prefill_cache_indices: Optional[torch.Tensor], prefill_cache_indices: Optional[torch.Tensor],
lm_head_indices: Optional[torch.Tensor] = None, lm_head_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> torch.Tensor: ) -> torch.Tensor:
true_max_s = max_s true_max_s = max_s
if prefill_cache_indices is not None: if prefill_cache_indices is not None:
@ -476,6 +526,7 @@ class FlashMistralForCausalLM(torch.nn.Module):
max_s, max_s,
true_max_s, true_max_s,
prefill_cache_indices, prefill_cache_indices,
adapter_data,
) )
if lm_head_indices is not None: if lm_head_indices is not None:
hidden_states = hidden_states[lm_head_indices] hidden_states = hidden_states[lm_head_indices]

View File

@ -3,7 +3,7 @@ import torch.distributed
from opentelemetry import trace from opentelemetry import trace
from transformers import AutoTokenizer, AutoConfig from transformers import AutoTokenizer, AutoConfig
from typing import Optional, Tuple from typing import Optional, Tuple, Dict, List
from text_generation_server.models import FlashCausalLM from text_generation_server.models import FlashCausalLM
from text_generation_server.models.flash_causal_lm import set_sliding_window from text_generation_server.models.flash_causal_lm import set_sliding_window
@ -21,6 +21,31 @@ from text_generation_server.utils.import_utils import SYSTEM
tracer = trace.get_tracer(__name__) tracer = trace.get_tracer(__name__)
Q_PROJ = "q_proj"
K_PROJ = "k_proj"
V_PROJ = "v_proj"
O_PROJ = "o_proj"
GATE_PROJ = "gate_proj"
UP_PROJ = "up_proj"
DOWN_PROJ = "down_proj"
LM_HEAD = "lm_head"
# TODO(travis): re-enable LM_HEAD after resolving issues with outputs
ADAPTER_LAYERS = [
Q_PROJ,
K_PROJ,
V_PROJ,
O_PROJ,
GATE_PROJ,
UP_PROJ,
DOWN_PROJ,
] # LM_HEAD
ROW_PARALLEL = {O_PROJ, DOWN_PROJ, LM_HEAD}
class BaseFlashMistral(FlashCausalLM): class BaseFlashMistral(FlashCausalLM):
def __init__( def __init__(
self, self,
@ -99,6 +124,62 @@ class BaseFlashMistral(FlashCausalLM):
model.model.head_size, model.model.head_size,
) )
@property
def supports_adapter_loading(self) -> bool:
return True
def adapter_target_to_layer(self) -> Dict[str, Tuple[str, torch.Tensor]]:
layer_weights = {}
prefix = "model.layers"
for i, layer in enumerate(self.model.model.layers):
layer_weights[(i, Q_PROJ)] = (
f"{prefix}.{i}.self_attn.q_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, K_PROJ)] = (
f"{prefix}.{i}.self_attn.k_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, V_PROJ)] = (
f"{prefix}.{i}.self_attn.v_proj",
layer.self_attn.query_key_value,
)
layer_weights[(i, O_PROJ)] = (
f"{prefix}.{i}.self_attn.o_proj",
layer.self_attn.o_proj,
)
layer_weights[(i, GATE_PROJ)] = (
f"{prefix}.{i}.mlp.gate_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, UP_PROJ)] = (
f"{prefix}.{i}.mlp.up_proj",
layer.mlp.gate_up_proj,
)
layer_weights[(i, DOWN_PROJ)] = (
f"{prefix}.{i}.mlp.down_proj",
layer.mlp.down_proj,
)
layer_weights[(0, LM_HEAD)] = ("lm_head", self.model.lm_head)
return layer_weights
@property
def adapter_layers(self) -> List[str]:
return ADAPTER_LAYERS
@property
def default_traced_adapter_layers(self) -> List[str]:
return [Q_PROJ, V_PROJ]
def get_num_layers_for_type(self, layer_type: str) -> int:
return 1 if layer_type == LM_HEAD else len(self.model.model.layers)
def is_row_parallel(self, layer_type: str) -> bool:
return layer_type in ROW_PARALLEL
class FlashMistral(BaseFlashMistral): class FlashMistral(BaseFlashMistral):
def __init__( def __init__(
@ -111,7 +192,6 @@ class FlashMistral(BaseFlashMistral):
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
super(FlashMistral, self).__init__( super(FlashMistral, self).__init__(
model_id=model_id,
config_cls=MistralConfig, config_cls=MistralConfig,
model_cls=FlashMistralForCausalLM, model_cls=FlashMistralForCausalLM,
model_id=model_id, model_id=model_id,

View File

@ -20,7 +20,6 @@ class FlashMixtral(BaseFlashMistral):
trust_remote_code: bool = False, trust_remote_code: bool = False,
): ):
super(FlashMixtral, self).__init__( super(FlashMixtral, self).__init__(
model_id=model_id,
config_cls=MixtralConfig, config_cls=MixtralConfig,
model_cls=FlashMixtralForCausalLM, model_cls=FlashMixtralForCausalLM,
model_id=model_id, model_id=model_id,