[WIP] Add support for Mistral-Nemo by supporting head_dim through config (#2254)
* Support passing head_dim through the config.
* Using `head_dim` as a fallback is necessary since it's a non-standard key in `MistralConfig` (as defined in transformers).
* Shorter diff.

Co-authored-by: Nicolas Patry <patry.nicolas@protonmail.com>
parent 9935720c87 · commit 3961e32390
@@ -149,15 +149,14 @@ class MistralAttention(torch.nn.Module):
             bias=False,
         )
 
-        head_size = config.hidden_size // config.num_attention_heads
         self.query_key_value = TensorParallelMultiAdapterLinear.load(
             query_key_value,
             layer_id,
             ["q_proj", "k_proj", "v_proj"],
             sizes=[
-                head_size * config.num_attention_heads,
-                head_size * config.num_key_value_heads,
-                head_size * config.num_key_value_heads,
+                self.head_size * config.num_attention_heads,
+                self.head_size * config.num_key_value_heads,
+                self.head_size * config.num_key_value_heads,
             ],
             process_group=weights.process_group,
         )
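The new projection sizes use self.head_size, which this change derives from the config's head_dim when that key is present. A minimal sketch of the fallback described in the commit message, assuming a transformers-style config object (the helper name get_head_size is hypothetical and only for illustration; it is not the repository's actual code):

# Minimal sketch (illustrative only): prefer an explicit `head_dim` from the
# config and fall back to the usual derivation, since `head_dim` is a
# non-standard key in transformers' MistralConfig and may be absent.
def get_head_size(config):
    head_dim = getattr(config, "head_dim", None)
    if head_dim is not None:
        return head_dim
    return config.hidden_size // config.num_attention_heads

Mistral-Nemo is the motivating case: its head_dim does not equal hidden_size // num_attention_heads, so deriving the head size from the hidden size alone would give the q/k/v projections the wrong shapes.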