fix: prefer adapter_data and refactors

drbh 2024-06-06 14:35:59 +00:00
parent 8b50f4b779
commit d5f21d57d1
3 changed files with 5 additions and 80 deletions


@@ -117,8 +117,6 @@ class FlashLlamaAttention(torch.nn.Module):
prefix: str,
config,
weights,
lora_weights,
lora_configs,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -151,40 +149,6 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
self.adapter_weights = {}
adapter_names = list(lora_weights.keys())
self.n_loras = len(adapter_names)
self.pre_multiplied_lora_matrices = torch.empty(
(self.n_loras, 2, self.hidden_size, self.hidden_size),
device=weights.device,
dtype=weights.dtype,
)
self.key_to_index = {}
lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
for adapter_index, adapter_name in enumerate(adapter_names):
self.lora_alpha = lora_configs[adapter_name].lora_alpha
self.lora_r = lora_configs[adapter_name].r
self.lora_scale = self.lora_alpha / self.lora_r
self.key_to_index[adapter_name] = adapter_index
adapter_weights = lora_weights[adapter_name]
for target_index, target in enumerate(["q", "v"]):
adapter_weight_a = adapter_weights.get_tensor(
f"{lora_prefix}.{target}_proj.lora_A.weight"
)
adapter_weight_b = adapter_weights.get_tensor(
f"{lora_prefix}.{target}_proj.lora_B.weight"
)
pre_multiplied_lora_matrix = torch.matmul(
adapter_weight_a.T * self.lora_scale,
adapter_weight_b.T,
).contiguous()
self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
pre_multiplied_lora_matrix
)
o_proj = TensorParallelRowLinear.load(
config,
@@ -216,8 +180,6 @@ class FlashLlamaAttention(torch.nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
):
qkv = self.query_key_value(hidden_states, adapter_data)
@@ -355,15 +317,13 @@ class LlamaMLP(nn.Module):
class FlashLlamaLayer(nn.Module):
def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
def __init__(self, index, prefix, config, weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
lora_weights=lora_weights,
lora_configs=lora_configs,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -390,8 +350,6 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
):
normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -407,8 +365,6 @@ class FlashLlamaLayer(nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
)
@@ -423,7 +379,7 @@ class FlashLlamaLayer(nn.Module):
class FlashLlamaModel(torch.nn.Module):
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
def __init__(self, prefix, config, weights):
super().__init__()
process_group = weights.process_group
@@ -440,8 +396,6 @@ class FlashLlamaModel(torch.nn.Module):
),
config=config,
weights=weights,
lora_weights=lora_weights,
lora_configs=lora_configs,
)
for layer_id in range(config.num_hidden_layers)
]
@@ -470,8 +424,6 @@ class FlashLlamaModel(torch.nn.Module):
max_s: int,
true_max_s: int,
prefill_cache_indices: Optional[torch.Tensor],
batch_lora_adapter_mask: Optional[List[str]],
lora_indices: Optional[torch.Tensor],
adapter_data,
) -> torch.Tensor:
hidden_states = inputs_embeds
@@ -495,8 +447,6 @@ class FlashLlamaModel(torch.nn.Module):
slots,
input_lengths,
max_s,
batch_lora_adapter_mask,
lora_indices,
adapter_data,
)
@@ -506,7 +456,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
class FlashLlamaForCausalLM(torch.nn.Module):
def __init__(self, prefix, config, weights, lora_weights, lora_configs):
def __init__(self, prefix, config, weights):
super().__init__()
self.embed_tokens = TensorParallelEmbedding(
@@ -515,9 +465,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
),
weights=weights,
)
self.model = FlashLlamaModel(
prefix, config, weights, lora_weights, lora_configs
)
self.model = FlashLlamaModel(prefix, config, weights)
if config.tie_word_embeddings:
suffix = "model.embed_tokens"
else:
@@ -544,8 +492,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s: int,
prefill_cache_indices: Optional[torch.Tensor] = None,
lm_head_indices: Optional[torch.Tensor] = None,
batch_lora_adapter_mask: Optional[List[str]] = None,
lora_indices: Optional[torch.Tensor] = None,
adapter_data: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
inputs_embeds = self.embed_tokens(input_ids)
@@ -560,8 +506,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
max_s,
true_max_s=max_s,
prefill_cache_indices=prefill_cache_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
adapter_data=adapter_data,
)
if lm_head_indices is not None:

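For context on the removed constructor code in the file above: it folded each adapter's LoRA A and B projection weights into a single dense per-layer, per-projection matrix up front. The sketch below reproduces that construction in isolation and checks it against the usual two-step LoRA application; the sizes (`hidden_size`, `r`, `lora_alpha`), the random tensors, and `x` are illustrative assumptions, not values from the repository. Only the `torch.matmul(adapter_weight_a.T * scale, adapter_weight_b.T)` form is taken from the removed code.

import torch

# Illustrative sizes only (assumptions, not values from the repo)
hidden_size, r, lora_alpha = 64, 8, 16
scale = lora_alpha / r

# PEFT-style checkpoint shapes: lora_A.weight is (r, hidden), lora_B.weight is (hidden, r)
adapter_weight_a = torch.randn(r, hidden_size)
adapter_weight_b = torch.randn(hidden_size, r)

# Pre-multiplied form used by the removed code:
# matmul(adapter_weight_a.T * scale, adapter_weight_b.T) -> (hidden, hidden)
pre_multiplied = torch.matmul(
    adapter_weight_a.T * scale, adapter_weight_b.T
).contiguous()

# The single matmul with the pre-multiplied matrix matches the usual
# two-step LoRA application (x @ A.T @ B.T, scaled by alpha / r).
x = torch.randn(4, hidden_size)
two_step = (x @ adapter_weight_a.T) @ adapter_weight_b.T * scale
assert torch.allclose(x @ pre_multiplied, two_step, atol=1e-4)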

@@ -877,8 +877,6 @@ class FlashCausalLM(Model):
graph = torch.cuda.CUDAGraph()
self.cuda_graphs[bs]["graph"] = graph
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
torch.cuda.synchronize()
# Run once outside to warmup
self.model.forward(
@@ -892,8 +890,6 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
torch.cuda.synchronize()
@@ -909,8 +905,6 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=None,
lm_head_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
self.cuda_graphs[bs]["logits"] = logits
self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
@@ -1038,10 +1032,6 @@ class FlashCausalLM(Model):
# Dummy value, some models (starcoder2) don't accept `None`.
input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
batch_lora_adapter_mask = torch.zeros(
seqlen, dtype=torch.bool, device=self.device
)
lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
# We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
self.model.forward(
@@ -1057,8 +1047,6 @@ class FlashCausalLM(Model):
max_s=seqlen,
lm_head_indices=None,
prefill_cache_indices=None,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
)
def forward(
@@ -1129,9 +1117,6 @@ class FlashCausalLM(Model):
else:
cuda_graph = None
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
if cu_seqlen_prefill is not None or cuda_graph is None:
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
@@ -1144,8 +1129,6 @@ class FlashCausalLM(Model):
max_s=max_s,
prefill_cache_indices=batch.prefill_cache_indices,
lm_head_indices=lm_head_indices,
batch_lora_adapter_mask=batch_lora_adapter_mask,
lora_indices=lora_indices,
adapter_data=adapter_data,
)
if batch.prefill_cache_indices is not None:

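For the second file: the removed lines allocated fixed-shape `batch_lora_adapter_mask` and `lora_indices` buffers before CUDA graph capture, warmup, and each forward call; CUDA graphs replay with fixed tensor shapes and addresses, which is why these placeholders had to exist up front. After this commit only `adapter_data` is threaded through `model.forward`. A minimal sketch of the dropped placeholders, with an assumed batch size and device:

import torch

bs = 8                                   # assumed CUDA-graph batch size (illustrative)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Static per-batch buffers formerly allocated and passed into model.forward:
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=device)
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=device)
# (the graph-capture and warmup paths initialized lora_indices with zeros,
#  the regular forward path with -1, as shown in the removed lines above)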

@@ -99,9 +99,7 @@ class FlashLlama(FlashCausalLM):
weights._set_gptq_params(model_id, revision)
prefix = ""
model = FlashLlamaForCausalLM(
prefix, config, weights, lora_weights, lora_configs
)
model = FlashLlamaForCausalLM(prefix, config, weights)
torch.distributed.barrier(group=self.process_group)
super(FlashLlama, self).__init__(
model_id=model_id,