[WIP] Add support for Mistral-Nemo by supporting head_dim through config (#2254)
* Support passing head_dim through config
* Using `head_dim` as a fallback is necessary since it's a non standard key in mistralConfig (as defined in transformers).
* Shorter diff.

---------

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
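The fallback described in the commit message can be sketched roughly as follows. This is a minimal illustration, not the code from the commit: the helper name `resolve_head_size` and the two example configs are hypothetical, and the numeric values are chosen only so that an explicit `head_dim` differs from `hidden_size // num_attention_heads`.

```python
from types import SimpleNamespace


def resolve_head_size(config):
    """Return the per-head dimension (hypothetical helper).

    Prefers an explicit `head_dim` attribute when the config has one,
    otherwise falls back to hidden_size // num_attention_heads.
    """
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return config.hidden_size // config.num_attention_heads


# Hypothetical config without `head_dim`: the size is derived.
derived_cfg = SimpleNamespace(hidden_size=4096, num_attention_heads=32)

# Hypothetical config with an explicit `head_dim` that does NOT equal
# hidden_size // num_attention_heads (5120 // 32 == 160).
explicit_cfg = SimpleNamespace(
    hidden_size=5120, num_attention_heads=32, head_dim=128
)

print(resolve_head_size(derived_cfg))   # 128, derived from 4096 // 32
print(resolve_head_size(explicit_cfg))  # 128, taken from head_dim
```

Using `getattr` with a default keeps the lookup safe for configs that do not define the key at all, which is the situation the second bullet describes.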
parent 9935720c87
commit 3961e32390
@@ -149,15 +149,14 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )
 
-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )
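To see why the `sizes` list in this hunk depends on which head size is used, here is a small arithmetic sketch. The function `qkv_sizes` is a hypothetical stand-in for the list built in the diff, and the head counts and hidden size are illustrative assumptions, not values taken from the commit.

```python
def qkv_sizes(head_size, num_attention_heads, num_key_value_heads):
    """Row counts of the q, k and v projections (hypothetical helper).

    Mirrors the three-entry `sizes` list from the diff: the query block
    scales with the attention heads, the key and value blocks with the
    key/value heads.
    """
    return [
        head_size * num_attention_heads,
        head_size * num_key_value_heads,
        head_size * num_key_value_heads,
    ]


# Illustrative assumptions: 32 attention heads, 8 key/value heads,
# hidden_size 5120, explicit head_dim 128.
from_head_dim = qkv_sizes(128, 32, 8)
from_hidden_size = qkv_sizes(5120 // 32, 32, 8)

print(from_head_dim)     # [4096, 1024, 1024]
print(from_hidden_size)  # [5120, 1280, 1280]
```

When the two results differ like this, splitting the fused projection with the derived head size would presumably not match the checkpoint's actual q/k/v shapes, which is why the hunk switches to `self.head_size`.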