feat: load weights within layer and refactor lora pass

This commit is contained in:
drbh 2024-06-04 01:38:43 +00:00
parent db3d8e6518
commit 0a6ea7fb57
3 changed files with 109 additions and 70 deletions

View File

@ -126,27 +126,51 @@ class FlashLlamaAttention(torch.nn.Module):
self.query_key_value = load_attention(config, prefix, weights)
self.index = index
self.adapter_weights = {}
for adapter_id, adapter_weights in all_adapter_weights.items():
filtered_keys = list(
filter(
lambda x: x.startswith(
f"base_model.model.model.layers.{index}.self_attn"
),
adapter_weights.keys(),
)
)
self.adapter_weights[adapter_id] = {
key: torch.tensor(
adapter_weights[key],
adapter_names = list(all_adapter_weights.keys())
self.lora_a_matrix = torch.empty(
(len(adapter_names), 2, 4096, 8),
device=weights.device,
dtype=weights.dtype,
).T
for key in filtered_keys
}
)
self.lora_b_matrix = torch.empty(
(len(adapter_names), 2, 8, 4096),
device=weights.device,
dtype=weights.dtype,
)
self.index_to_key = {
i: key for i, key in enumerate(self.adapter_weights.keys())
}
self.pre_multiplied_lora_matrix = torch.empty(
(len(adapter_names), 2, 4096, 4096),
device=weights.device,
dtype=weights.dtype,
)
self.key_to_index = {}
self.index_to_key = {}
lora_prefix = f"base_model.model.model.layers.{index}.self_attn"
for adapter_index, adapter_name in enumerate(adapter_names):
self.lora_alpha = 16.0
self.lora_r = 8.0
self.lora_scale = self.lora_alpha / self.lora_r
self.key_to_index[adapter_name] = adapter_index
self.index_to_key[adapter_index] = adapter_name
adapter_weights = all_adapter_weights[adapter_name]
for target_index, target in enumerate(["q", "v"]):
adapter_weight_a = adapter_weights.get_tensor(
f"{lora_prefix}.{target}_proj.lora_A.weight"
)
adapter_weight_b = adapter_weights.get_tensor(
f"{lora_prefix}.{target}_proj.lora_B.weight"
)
pre_multiplied_lora_matrix = torch.matmul(
adapter_weight_a.T * self.lora_scale,
adapter_weight_b.T,
).contiguous()
self.pre_multiplied_lora_matrix[adapter_index, target_index, :, :] = (
pre_multiplied_lora_matrix
)
self.o_proj = TensorParallelRowLinear.load(
config,
@ -159,23 +183,6 @@ class FlashLlamaAttention(torch.nn.Module):
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)
def get_adapter_weights(self, lora_index):
adapter_id = self.index_to_key[lora_index]
q_proj_lora_a = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_A.weight"
]
q_proj_lora_b = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.q_proj.lora_B.weight"
]
v_proj_lora_a = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_A.weight"
]
v_proj_lora_b = self.adapter_weights[adapter_id][
f"base_model.model.model.layers.{self.index}.self_attn.v_proj.lora_B.weight"
]
return q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b
def forward(
self,
hidden_states,
@ -201,39 +208,42 @@ class FlashLlamaAttention(torch.nn.Module):
query = query.view(-1, self.num_heads, self.head_size)
kv = kv.view(-1, 2, self.num_key_value_heads, self.head_size)
q_proj_lora_a, q_proj_lora_b, v_proj_lora_a, v_proj_lora_b = (
self.get_adapter_weights(
# TODO: dont just assume the first adapter
lora_indices[0].item()
)
)
query_adapted = torch.matmul(
hidden_states,
torch.matmul(
q_proj_lora_a,
q_proj_lora_b,
),
)
value_adapted = torch.matmul(
hidden_states,
torch.matmul(
v_proj_lora_a,
v_proj_lora_b,
),
)
batch_size = query.size(0)
if not torch.all(lora_indices, -1):
lora_mask = lora_indices[lora_indices != -1]
# TODO: improve this to avoid unnecessary work
# mask across batch and within lora adapters
query[batch_lora_adapter_mask] += query_adapted.view(
batch_size, self.num_heads, self.head_size
)[batch_lora_adapter_mask]
kv[batch_lora_adapter_mask, 1] += value_adapted.view(
batch_size, self.num_key_value_heads, self.head_size
)[batch_lora_adapter_mask]
q_pre_multiplied_batch = torch.ones(
(batch_size, 4096, 4096),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
q_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
lora_mask, 0
]
v_pre_multiplied_batch = torch.ones(
(batch_size, 4096, 4096),
device=hidden_states.device,
dtype=hidden_states.dtype,
)
v_pre_multiplied_batch[lora_mask] = self.pre_multiplied_lora_matrix[
lora_mask, 1
]
query_adapted = (
torch.bmm(hidden_states.unsqueeze(1), q_pre_multiplied_batch)
.squeeze(1)
.view(batch_size, self.num_heads, self.head_size)
)
value_adapted = (
torch.bmm(hidden_states.unsqueeze(1), v_pre_multiplied_batch)
.squeeze(1)
.view(batch_size, self.num_key_value_heads, self.head_size)
)
query[batch_lora_adapter_mask] += query_adapted[batch_lora_adapter_mask]
kv[batch_lora_adapter_mask, 1] += value_adapted[batch_lora_adapter_mask]
self.rotary_emb(query, torch.select(kv, dim=1, index=0), cos, sin)
@ -503,6 +513,9 @@ class FlashLlamaForCausalLM(torch.nn.Module):
weights=weights,
)
def get_lora_index(self, adapter_id):
return self.model.layers[0].self_attn.key_to_index[adapter_id]
def forward(
self,
input_ids: torch.Tensor,

View File

@ -1064,11 +1064,11 @@ class FlashCausalLM(Model):
cuda_graph = None
batch_lora_adapter_mask = torch.zeros(bs, dtype=torch.bool, device=self.device)
lora_indices = torch.zeros(bs, dtype=torch.int32, device=self.device)
lora_indices = torch.full((bs,), -1, dtype=torch.int32, device=self.device)
for i, r in enumerate(batch.requests):
if r.adapter_id:
lora_index = int(r.adapter_id)
lora_index = self.model.get_lora_index(r.adapter_id)
lora_indices[i] = lora_index
batch_lora_adapter_mask[i] = True

View File

@ -18,6 +18,17 @@ WEIGHTS_CACHE_OVERRIDE = os.getenv("WEIGHTS_CACHE_OVERRIDE", None)
HF_HUB_OFFLINE = os.environ.get("HF_HUB_OFFLINE", "0").lower() in ["true", "1", "yes"]
def _cached_adapter_weight_files(
adapter_id: str, revision: Optional[str], extension: str
) -> List[str]:
"""Guess weight files from the cached revision snapshot directory"""
d = _get_cached_revision_directory(adapter_id, revision)
if not d:
return []
filenames = _adapter_weight_files_from_dir(d, extension)
return filenames
def _cached_weight_files(
model_id: str, revision: Optional[str], extension: str
) -> List[str]:
@ -60,6 +71,21 @@ def _weight_files_from_dir(d: Path, extension: str) -> List[str]:
return filenames
def _adapter_weight_files_from_dir(d: Path, extension: str) -> List[str]:
# os.walk: do not iterate, just scan for depth 1, not recursively
# see _weight_files_from_dir, that's also what is done there
root, _, files = next(os.walk(str(d)))
filenames = [
os.path.join(root, f)
for f in files
if f.endswith(extension)
and "arguments" not in f
and "args" not in f
and "training" not in f
]
return filenames
def _get_cached_revision_directory(
model_id: str, revision: Optional[str]
) -> Optional[Path]: