feat: adjust attn weight loading logic (#1975)

This PR updates `load_attention` to prefer loading specific attention
based on the model type. Additionally there were two cases where
`TensorParallelColumnLinear.load_multi` was called and this reduces it
to a single path
This commit is contained in:
drbh 2024-05-29 12:42:11 -04:00 committed by GitHub
parent 612bc483b6
commit cbced7f0f9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 21 additions and 27 deletions

View File

@ -49,37 +49,31 @@ if SYSTEM == "rocm":
def load_attention(config, prefix, weights):
bias = config.attention_bias
if config.num_attention_heads != config.num_key_value_heads:
return TensorParallelColumnLinear.load_multi(
# if specific model type, load the correct attention
if config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
)
else:
if config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=bias,
)
elif config.model_type == "phi3":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.qkv_proj",
weights=weights,
bias=bias,
)
else:
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
elif config.model_type == "baichuan":
return TensorParallelColumnLinear.load_qkv(
config,
prefix=f"{prefix}.W_pack",
weights=weights,
bias=bias,
)
# otherwise, load the default attention based on the number of heads
return TensorParallelColumnLinear.load_multi(
config,
prefixes=[f"{prefix}.q_proj", f"{prefix}.k_proj", f"{prefix}.v_proj"],
dim=0,
weights=weights,
bias=bias,
)
class FlashLlamaAttention(torch.nn.Module):