feat: load weights within layer and refactor lora pass
parent db3d8e6518
commit 0a6ea7fb57
@@ -126,27 +126,51 @@ class FlashLlamaAttention(torch.nn.Module):
        self.query_key_value = load_attention(config, prefix, weights)
        self.index = index
        self.adapter_weights = {}
        for adapter_id, adapter_weights in all_adapter_weights.items():
            filtered_keys = list(
                filter(
                    lambda x: x.startswith(
                        f"base_model.model.model.layers.{index}.self_attn"
                    ),
                    adapter_weights.keys(),
                )
            )
            self.adapter_weights[adapter_id] = {
                key: torch.tensor(
                    adapter_weights[key],
                    device=weights.device,
                    dtype=weights.dtype,
                ).T
                for key in filtered_keys
            }
        adapter_names = list(all_adapter_weights.keys())

        self.index_to_key = {
            i: key for i, key in enumerate(self.adapter_weights.keys())
        }
        self.lora_a_matrix = torch.empty(
            (len(adapter_names), 2, 4096, 8),
            device=weights.device,
            dtype=weights.dtype,
        )
        self.lora_b_matrix = torch.empty(
            (len(adapter_names), 2, 8, 4096),
            device=weights.device,
            dtype=weights.dtype,
        )

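        # One pre-multiplied LoRA delta per (adapter, target) pair:
        # [num_adapters, 2 (q_proj / v_proj), 4096, 4096], with the hidden size
        # hard-coded to 4096 rather than taken from the model config.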
        self.pre_multiplied_lora_matrix = torch.empty(
            (len(adapter_names), 2, 4096, 4096),
            device=weights.device,
            dtype=weights.dtype,
        )

        self.key_to_index = {}
        self.index_to_key = {}

        lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
        for adapter_index, adapter_name in enumerate(adapter_names):
            self.lora_alpha = 16.0
            self.lora_r = 8.0
            self.lora_scale = self.lora_alpha / self.lora_r
            self.key_to_index[adapter_name] = adapter_index
            self.index_to_key[adapter_index] = adapter_name
            adapter_weights = all_adapter_weights[adapter_name]
            for target_index, target in enumerate(["q", "v"]):
                adapter_weight_a = adapter_weights.get_tensor(
                    f"{lora_prefix}.{target}_proj.lora_A.weight"
                )
                adapter_weight_b = adapter_weights.get_tensor(
                    f"{lora_prefix}.{target}_proj.lora_B.weight"
                )
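                # lora_A.weight is (r, hidden) and lora_B.weight is (hidden, r) in the
                # usual PEFT layout, so A.T (4096 x 8) @ B.T (8 x 4096), scaled by
                # alpha / r = 16 / 8 = 2.0, gives the transposed LoRA delta; the forward
                # pass can then apply an adapter with a single matmul per target.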
                pre_multiplied_lora_matrix = torch.matmul(
                    adapter_weight_a.T * self.lora_scale,
                    adapter_weight_b.T,
                ).contiguous()

                self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
                    pre_multiplied_lora_matrix
                )

        self.o_proj = TensorParallelRowLinear.load(
            config,
@@ -159,23 +183,6 @@ class FlashLlamaAttention(torch.nn.Module):
            0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
        ).repeat_interleave(self.num_groups)

    def get_adapter_weights(self, lora_index):
        adapter_id = self.index_to_key[lora_index]
        q_proj_lora_a = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
        ]
        q_proj_lora_b = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
        ]

        v_proj_lora_a = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
        ]
        v_proj_lora_b = self.adapter_weights[adapter_id][
            f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
        ]
        return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b

    def forward(
        self,
        hidden_states,

@@ -201,39 +208,42 @@ class FlashLlamaAttention(torch.nn.Module):
        query = query.view(-1, self.num_heads, self.head_size)
        kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)

        q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
            self.get_adapter_weights(
                # TODO: dont just assume the first adapter
                lora_indices[0].item()
            )
        )

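        # Single-adapter path: build the LoRA update as hidden_states @ (lora_A @ lora_B)
        # using only the adapter selected from the first request above.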
        query_adapted = torch.matmul(
            hidden_states,
            torch.matmul(
                q_proj_lora_a,
                q_proj_lora_b,
            ),
        )

        value_adapted = torch.matmul(
            hidden_states,
            torch.matmul(
                v_proj_lora_a,
                v_proj_lora_b,
            ),
        )

        batch_size = query.size(0)
        if not torch.all(lora_indices, -1):
            lora_mask = lora_indices[lora_indices != -1]

            # TODO: improve this to avoid unnecessary work
            # mask across batch and within lora adapters
            query[batch_lora_adapter_mask] += query_adapted.view(
                batch_size, self.num_heads, self.head_size
            )[batch_lora_adapter_mask]
            kv[batch_lora_adapter_mask, 1] += value_adapted.view(
                batch_size, self.num_key_value_heads, self.head_size
            )[batch_lora_adapter_mask]
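            # Build a per-sequence (4096 x 4096) delta for the q and v projections by
            # copying rows of the stacked pre-multiplied buffer at the indices gathered above.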
            q_pre_multiplied_batch = torch.ones(
                (batch_size, 4096, 4096),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )

            q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
                lora_mask, 0
            ]

            v_pre_multiplied_batch = torch.ones(
                (batch_size, 4096, 4096),
                device=hidden_states.device,
                dtype=hidden_states.dtype,
            )

            v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
                lora_mask, 1
            ]

            query_adapted = (
                torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch)
                .squeeze(1)
                .view(batch_size, self.num_heads, self.head_size)
            )
            value_adapted = (
                torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch)
                .squeeze(1)
                .view(batch_size, self.num_key_value_heads, self.head_size)
            )
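            # Only rows flagged in batch_lora_adapter_mask receive the adapted projections,
            # so requests without an adapter in the same batch keep the base q / v outputs.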
            query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
            kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]

        self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)


@@ -503,6 +513,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
            weights=weights,
        )

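    # Adapter ids are mapped to buffer slots via the attention module of layer 0,
    # which assumes every layer registered the adapters in the same order.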
    def get_lora_index(self, adapter_id):
        return self.model.layers[0].self_attn.key_to_index[adapter_id]

    def forward(
        self,
        input_ids: torch.Tensor,

@@ -1064,11 +1064,11 @@ class FlashCausalLM(Model):
        cuda_graph = None

        batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
        lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
        lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
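        # -1 is the sentinel for "no adapter"; requests with an adapter_id are mapped to
        # their slot through get_lora_index instead of casting the id to an int.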

        for i, r in enumerate(batch.requests):
            if r.adapter_id:
                lora_index = int(r.adapter_id)
                lora_index = self.model.get_lora_index(r.adapter_id)
                lora_indices[i] = lora_index
                batch_lora_adapter_mask[i] = True

@@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]

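# Mirrors _cached_weight_files below, but for adapter repos. A hypothetical call such as
# _cached_adapter_weight_files("some-user/some-lora", None, ".safetensors") would return
# the cached adapter file paths, or [] when no snapshot of that revision is cached locally.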
def _cached_adapter_weight_files(
    adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
    """Guess weight files from the cached revision snapshot directory"""
    d = _get_cached_revision_directory(adapter_id, revision)
    if not d:
        return []
    filenames = _adapter_weight_files_from_dir(d, extension)
    return filenames


def _cached_weight_files(
    model_id: str, revision: Optional[str], extension: str
) -> List[str]:

@@ -60,6 +71,21 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
    return filenames

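# Like _weight_files_from_dir, but also skips files whose names contain "arguments",
# "args", or "training", so training artifacts saved next to the adapter are not
# picked up as weight files.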
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
    # os.walk: do not iterate, just scan for depth 1, not recursively
    # see _weight_files_from_dir, that's also what is done there
    root, _, files = next(os.walk(str(d)))
    filenames = [
        os.path.join(root, f)
        for f in files
        if f.endswith(extension)
        and "arguments" not in f
        and "args" not in f
        and "training" not in f
    ]
    return filenames


def _get_cached_revision_directory(
    model_id: str, revision: Optional[str]
) -> Optional[Path]: