fix: prefer adapter_data and refactors
This commit is contained in:
parent 8b50f4b779
commit d5f21d57d1
@@ -117,8 +117,6 @@ class FlashLlamaAttention(torch.nn.Module):
         prefix: str,
         config,
         weights,
-        lora_weights,
-        lora_configs,
     ):
         super().__init__()
         self.num_heads = config.num_attention_heads
@@ -151,40 +149,6 @@ class FlashLlamaAttention(torch.nn.Module):
 
         self.query_key_value = load_attention(config, prefix, weights, index)
         self.index = index
-        self.adapter_weights = {}
-        adapter_names = list(lora_weights.keys())
-
-        self.n_loras = len(adapter_names)
-        self.pre_multiplied_lora_matrices = torch.empty(
-            (self.n_loras, 2, self.hidden_size, self.hidden_size),
-            device=weights.device,
-            dtype=weights.dtype,
-        )
-
-        self.key_to_index = {}
-
-        lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
-        for adapter_index, adapter_name in enumerate(adapter_names):
-            self.lora_alpha = lora_configs[adapter_name].lora_alpha
-            self.lora_r = lora_configs[adapter_name].r
-            self.lora_scale = self.lora_alpha / self.lora_r
-            self.key_to_index[adapter_name] = adapter_index
-            adapter_weights = lora_weights[adapter_name]
-            for target_index, target in enumerate(["q", "v"]):
-                adapter_weight_a = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_A.weight"
-                )
-                adapter_weight_b = adapter_weights.get_tensor(
-                    f"{lora_prefix}.{target}_proj.lora_B.weight"
-                )
-                pre_multiplied_lora_matrix = torch.matmul(
-                    adapter_weight_a.T * self.lora_scale,
-                    adapter_weight_b.T,
-                ).contiguous()
-
-                self.pre_multiplied_lora_matrices[adapter_index, target_index, :, :] = (
-                    pre_multiplied_lora_matrix
-                )
-
         o_proj = TensorParallelRowLinear.load(
             config,
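The block removed above folded each adapter's LoRA factors into one dense per-projection matrix at load time. For reference, a minimal standalone sketch of that pre-multiplication and why it is equivalent to the usual two-matmul LoRA path (shapes follow the PEFT convention; sizes and names here are illustrative, not taken from the repo):

import torch

hidden_size, r, lora_alpha = 4096, 16, 32
lora_a = torch.randn(r, hidden_size)      # lora_A.weight, shape (r, hidden)
lora_b = torch.randn(hidden_size, r)      # lora_B.weight, shape (hidden, r)
lora_scale = lora_alpha / r

# (hidden, r) @ (r, hidden) -> a single (hidden, hidden) matrix, so the
# forward pass pays one dense matmul instead of two low-rank ones.
pre_multiplied = torch.matmul(lora_a.T * lora_scale, lora_b.T).contiguous()

x = torch.randn(8, hidden_size)            # a batch of hidden states
delta = x @ pre_multiplied                 # == ((x @ lora_a.T) * lora_scale) @ lora_b.T

The trade-off is a dense hidden×hidden buffer per adapter and target, which is presumably part of what this commit walks back in favor of routing `adapter_data` through the layers.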
@@ -216,8 +180,6 @@ class FlashLlamaAttention(torch.nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         qkv = self.query_key_value(hidden_states, adapter_data)
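With the mask and index tensors gone, the attention forward simply hands `adapter_data` to the fused QKV projection. The diff does not show how that layer consumes it; one plausible shape for such a layer, sketched with invented names (`LoraAwareLinear`, the `(lora_a, lora_b, scale)` triples), not the actual TGI implementation:

import torch
import torch.nn as nn

class LoraAwareLinear(nn.Module):
    # Hypothetical sketch: a linear projection that applies optional
    # low-rank corrections carried alongside the base weights.
    def __init__(self, base: nn.Linear):
        super().__init__()
        self.base = base

    def forward(self, x: torch.Tensor, adapter_data=None) -> torch.Tensor:
        out = self.base(x)
        if adapter_data is not None:
            # assume adapter_data yields (lora_a, lora_b, scale) triples
            for lora_a, lora_b, scale in adapter_data:
                out = out + ((x @ lora_a.T) * scale) @ lora_b.T
        return out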
@@ -355,15 +317,13 @@ class LlamaMLP(nn.Module):
 
 
 class FlashLlamaLayer(nn.Module):
-    def __init__(self, index, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, index, prefix, config, weights):
         super().__init__()
         self.self_attn = FlashLlamaAttention(
             index=index,
             prefix=f"{prefix}.self_attn",
             config=config,
             weights=weights,
-            lora_weights=lora_weights,
-            lora_configs=lora_configs,
         )
         self.mlp = LlamaMLP(
             prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -390,8 +350,6 @@ class FlashLlamaLayer(nn.Module):
         slots,
         input_lengths,
         max_s,
-        batch_lora_adapter_mask,
-        lora_indices,
         adapter_data,
     ):
         normed_hidden_states, res = self.input_layernorm(hidden_states, residual)
@@ -407,8 +365,6 @@ class FlashLlamaLayer(nn.Module):
             slots,
             input_lengths,
             max_s,
-            batch_lora_adapter_mask,
-            lora_indices,
             adapter_data,
         )
 
@@ -423,7 +379,7 @@ class FlashLlamaLayer(nn.Module):
 
 
 class FlashLlamaModel(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         process_group = weights.process_group
@@ -440,8 +396,6 @@ class FlashLlamaModel(torch.nn.Module):
                     ),
                     config=config,
                     weights=weights,
-                    lora_weights=lora_weights,
-                    lora_configs=lora_configs,
                 )
                 for layer_id in range(config.num_hidden_layers)
             ]
@@ -470,8 +424,6 @@ class FlashLlamaModel(torch.nn.Module):
         max_s: int,
         true_max_s: int,
         prefill_cache_indices: Optional[torch.Tensor],
-        batch_lora_adapter_mask: Optional[List[str]],
-        lora_indices: Optional[torch.Tensor],
         adapter_data,
     ) -> torch.Tensor:
         hidden_states = inputs_embeds
@@ -495,8 +447,6 @@ class FlashLlamaModel(torch.nn.Module):
                 slots,
                 input_lengths,
                 max_s,
-                batch_lora_adapter_mask,
-                lora_indices,
                 adapter_data,
             )
 
@@ -506,7 +456,7 @@ class FlashLlamaModel(torch.nn.Module):
 
 
 class FlashLlamaForCausalLM(torch.nn.Module):
-    def __init__(self, prefix, config, weights, lora_weights, lora_configs):
+    def __init__(self, prefix, config, weights):
         super().__init__()
 
         self.embed_tokens = TensorParallelEmbedding(
@@ -515,9 +465,7 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             ),
             weights=weights,
         )
-        self.model = FlashLlamaModel(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        self.model = FlashLlamaModel(prefix, config, weights)
         if config.tie_word_embeddings:
             suffix = "model.embed_tokens"
         else:
@@ -544,8 +492,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
         max_s: int,
         prefill_cache_indices: Optional[torch.Tensor] = None,
         lm_head_indices: Optional[torch.Tensor] = None,
-        batch_lora_adapter_mask: Optional[List[str]] = None,
-        lora_indices: Optional[torch.Tensor] = None,
         adapter_data: Optional[torch.Tensor] = None,
     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
         inputs_embeds = self.embed_tokens(input_ids)
@@ -560,8 +506,6 @@ class FlashLlamaForCausalLM(torch.nn.Module):
             max_s,
             true_max_s=max_s,
             prefill_cache_indices=prefill_cache_indices,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
             adapter_data=adapter_data,
         )
         if lm_head_indices is not None:

@@ -877,8 +877,6 @@ class FlashCausalLM(Model):
         graph = torch.cuda.CUDAGraph()
         self.cuda_graphs[bs]["graph"] = graph
 
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
         torch.cuda.synchronize()
         # Run once outside to warmup
         self.model.forward(
@@ -892,8 +890,6 @@ class FlashCausalLM(Model):
             max_s=max_s,
             prefill_cache_indices=None,
             lm_head_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )
         torch.cuda.synchronize()
 
@@ -909,8 +905,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=None,
                 lm_head_indices=None,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
             )
         self.cuda_graphs[bs]["logits"] = logits
         self.cuda_graphs[bs]["speculative_logits"] = speculative_logits
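The two hunks above touch the CUDA-graph path: the buffers for the dropped tensors disappear from both the eager warmup call and the captured call. For orientation, a minimal sketch of the warmup-then-capture pattern this code follows (`model_forward` and `static_inputs` are illustrative stand-ins, not the real signature):

import torch

def capture_decode_graph(model_forward, static_inputs):
    graph = torch.cuda.CUDAGraph()

    torch.cuda.synchronize()
    model_forward(**static_inputs)      # run once eagerly to warm up kernels
    torch.cuda.synchronize()

    with torch.cuda.graph(graph):       # record the same call into the graph
        logits = model_forward(**static_inputs)

    # graph.replay() later re-runs the recorded kernels against the same
    # static input/output buffers, so callers copy new data in before replay.
    return graph, logits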
@@ -1038,10 +1032,6 @@ class FlashCausalLM(Model):
 
         # Dummy value, some models (starcoder2) don't accept `None`.
         input_lengths = torch.ones(seqlen, dtype=torch.int32, device=self.device)
-        batch_lora_adapter_mask = torch.zeros(
-            seqlen, dtype=torch.bool, device=self.device
-        )
-        lora_indices = torch.zeros(seqlen, dtype=torch.int32, device=self.device)
 
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
@@ -1057,8 +1047,6 @@ class FlashCausalLM(Model):
             max_s=seqlen,
             lm_head_indices=None,
             prefill_cache_indices=None,
-            batch_lora_adapter_mask=batch_lora_adapter_mask,
-            lora_indices=lora_indices,
         )
 
     def forward(
@@ -1129,9 +1117,6 @@ class FlashCausalLM(Model):
         else:
             cuda_graph = None
 
-        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
-        lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
-
         if cu_seqlen_prefill is not None or cuda_graph is None:
             logits, speculative_logits = self.model.forward(
                 input_ids=input_ids,
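Worth noting before it disappears: the removed decode-path buffer initialized `lora_indices` to `-1`, a per-sequence "no adapter" sentinel. A hedged sketch of how such a sentinel can gate a batched adapter lookup (sizes and names are illustrative only, not the TGI code):

import torch

n_loras, hidden = 2, 8
deltas = torch.randn(n_loras, hidden, hidden)    # stacked per-adapter matrices
lora_indices = torch.tensor([0, -1, 1, -1])      # one entry per sequence; -1 = none

x = torch.randn(4, hidden)
out = x.clone()
active = lora_indices >= 0
if active.any():
    # gather each active sequence's matrix and apply all of them in one bmm
    sel = deltas[lora_indices[active].long()]            # (n_active, hidden, hidden)
    out[active] += torch.bmm(x[active].unsqueeze(1), sel).squeeze(1)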
@@ -1144,8 +1129,6 @@ class FlashCausalLM(Model):
                 max_s=max_s,
                 prefill_cache_indices=batch.prefill_cache_indices,
                 lm_head_indices=lm_head_indices,
-                batch_lora_adapter_mask=batch_lora_adapter_mask,
-                lora_indices=lora_indices,
                 adapter_data=adapter_data,
             )
             if batch.prefill_cache_indices is not None:

@@ -99,9 +99,7 @@ class FlashLlama(FlashCausalLM):
             weights._set_gptq_params(model_id, revision)
 
         prefix = ""
-        model = FlashLlamaForCausalLM(
-            prefix, config, weights, lora_weights, lora_configs
-        )
+        model = FlashLlamaForCausalLM(prefix, config, weights)
         torch.distributed.barrier(group=self.process_group)
         super(FlashLlama, self).__init__(
             model_id=model_id,