fix: prefer adapter_data and refactors

drbh 2024-06-06 14:35:59 +00:00
parent 8b50f4b779
commit d5f21d57d1
3 changed files with 5 additions and 80 deletions
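The refactor drops the per-layer `lora_weights` / `lora_configs` constructor arguments and the `batch_lora_adapter_mask` / `lora_indices` forward arguments, leaving `adapter_data` as the single adapter-related input threaded through the stack. Below is a minimal, illustrative sketch of the constructor chain that remains; the bodies are stubs, the per-layer prefix string is a simplification, and the real classes live in the modeling file changed below.

import torch.nn as nn


# Illustrative stubs only: signatures mirror the diff below, real bodies omitted.
class FlashLlamaAttention(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()  # no lora_weights / lora_configs anymore


class FlashLlamaLayer(nn.Module):
    def __init__(self, index, prefix, config, weights):
        super().__init__()
        self.self_attn = FlashLlamaAttention(
            index=index, prefix=f"{prefix}.self_attn", config=config, weights=weights
        )


class FlashLlamaModel(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.layers = nn.ModuleList(
            [
                # prefix handling simplified here (assumption)
                FlashLlamaLayer(
                    index=i, prefix=f"model.layers.{i}", config=config, weights=weights
                )
                for i in range(config.num_hidden_layers)
            ]
        )


class FlashLlamaForCausalLM(nn.Module):
    def __init__(self, prefix, config, weights):
        super().__init__()
        self.model = FlashLlamaModel(prefix, config, weights)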

View File

@@ -117,8 +117,6 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        lora_weights,
-        lora_configs,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -151,40 +149,6 @@ class FlashLlamaAttention(torch.nn.Module):
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
-        self.adapter_weights = {}
-        adapter_names = list(lora_weights.keys())
-        self.n_loras = len(adapter_names)
-        self.pre_multiplied_lora_matrices = torch.empty(
-            (self.n_loras, 2, self.hidden_size, self.hidden_size),
-            device=weights.device,
-            dtype=weights.dtype,
-        )
-        self.key_to_index = {}
-        lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
-        for adapter_index, adapter_name in enumerate(adapter_names):
-            self.lora_alpha = lora_configs[adapter_name].lora_alpha
-            self.lora_r = lora_configs[adapter_name].r
-            self.lora_scale = self.lora_alpha / self.lora_r
-            self.key_to_index[adapter_name] = adapter_index
-            adapter_weights = lora_weights[adapter_name]
-            for target_index, target in enumerate(["q", "v"]):
-                adapter_weight_a = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_A.weight"
-                )
-                adapter_weight_b = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_B.weight"
-                )
-                pre_multiplied_lora_matrix = torch.matmul(
-                    adapter_weight_a.T * self.lora_scale,
-                    adapter_weight_b.T,
-                ).contiguous()
-                self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
-                    pre_multiplied_lora_matrix
-                )
         o_proj = TensorParallelRowLinear.load(
             config,
@@ -216,8 +180,6 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
@@ -355,15 +317,13 @@ class LlamaMLP(nn.Module):
 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            lora_weights=lora_weights,
-            lora_configs=lora_configs,
         )
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -390,8 +350,6 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -407,8 +365,6 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
-            batch_lora_adapter_mask,
-            lora_indices,
             adapter_data,
         )
@@ -423,7 +379,7 @@ class FlashLlamaLayer(nn.Module):
 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         process_group = weights.process_group
@@ -440,8 +396,6 @@ class FlashLlamaModel(torch.nn.Module):
                     ),
                     config=config,
                     weights=weights,
-                    lora_weights=lora_weights,
-                    lora_configs=lora_configs,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -470,8 +424,6 @@ class FlashLlamaModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-        batch_lora_adapter_mask: Optional[List[str]],
-        lora_indices: Optional[torch.Tensor],
         adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -495,8 +447,6 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
-                batch_lora_adapter_mask,
-                lora_indices,
                 adapter_data,
             )
@@ -506,7 +456,7 @@ class FlashLlamaModel(torch.nn.Module):
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
         self.embed_tokens = TensorParallelEmbedding(
@@ -515,9 +465,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -544,8 +492,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
-        batch_lora_adapter_mask: Optional[List[str]] = None,
-        lora_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -560,8 +506,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
             adapter_data=adapter_data,
         )
         if lm_head_indices is not None:
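For reference, the deleted block above pre-folded each adapter's LoRA pair into a dense (hidden_size, hidden_size) matrix per target projection at load time. Below is a small self-contained sketch of that math, using random stand-ins for the real lora_A / lora_B tensors and a hypothetical rank/alpha; the real code read these from the adapter safetensors.

import torch

# Stand-in shapes; r and lora_alpha here are hypothetical example values.
hidden_size, r, lora_alpha = 4096, 16, 32
lora_scale = lora_alpha / r

lora_a = torch.randn(r, hidden_size)   # lora_A.weight: (r, hidden)
lora_b = torch.randn(hidden_size, r)   # lora_B.weight: (hidden, r)

# (hidden, r) @ (r, hidden) -> (hidden, hidden); one such matrix was stored per
# adapter and per target ("q", "v") in pre_multiplied_lora_matrices.
pre_multiplied = torch.matmul(lora_a.T * lora_scale, lora_b.T).contiguous()
assert pre_multiplied.shape == (hidden_size, hidden_size)

With this buffer gone, adapter weights are no longer baked into per-layer state at construction; adapter handling moves to the `adapter_data` argument passed at forward time (e.g. into `self.query_key_value(hidden_states, adapter_data)`).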

View File

@@ -877,8 +877,6 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
         torch.cuda.synchronize()
         # Run once outside to warmup
         self.model.forward(
@@ -892,8 +890,6 @@ class FlashCausalLM(Model):
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )
         torch.cuda.synchronize()
@@ -909,8 +905,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
             )
         self.cuda_graphs[bs]["logits"] = logits
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@@ -1038,10 +1032,6 @@ class FlashCausalLM(Model):
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        batch_lora_adapter_mask = torch.zeros(
-            seqlen, dtype=torch.bool, device=self.device
-        )
-        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1057,8 +1047,6 @@ class FlashCausalLM(Model):
             max_s=seqlen,
             lm_head_indices=None,
             prefill_cache_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )

     def forward(
@@ -1129,9 +1117,6 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
@@ -1144,8 +1129,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:
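On the runtime side, the CUDA-graph warmup/capture and the main forward path stop allocating dummy `batch_lora_adapter_mask` / `lora_indices` tensors. The sketch below shows the forward signature these call sites now target; it is a hedged reconstruction, the arguments above `max_s` and the `kv_cache` typing are inferred from the surrounding call sites and may not match the file exactly.

from typing import List, Optional, Tuple

import torch


def forward(
    input_ids: torch.Tensor,
    position_ids: torch.Tensor,
    cu_seqlen_prefill: Optional[torch.Tensor],
    kv_cache: List[Tuple[torch.Tensor, torch.Tensor]],  # assumed layout
    block_tables: torch.Tensor,
    slots: torch.Tensor,
    input_lengths: torch.Tensor,
    max_s: int,
    prefill_cache_indices: Optional[torch.Tensor] = None,
    lm_head_indices: Optional[torch.Tensor] = None,
    adapter_data: Optional[torch.Tensor] = None,  # the only adapter input left
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    ...  # body elided; see FlashLlamaForCausalLM.forward in the modeling file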

View File

@@ -99,9 +99,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)

         prefix = ""
-        model = FlashLlamaForCausalLM(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model_id=model_id,