diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py
index f6ffb4aa..1cd13a2a 100644
--- a/server/text_generation_server/models/__init__.py
+++ b/server/text_generation_server/models/__init__.py
@@ -762,8 +762,6 @@ def get_model(
                 default_dtype=torch.bfloat16,
                 trust_remote_code=trust_remote_code,
                 lora_adapter_ids=lora_adapter_ids,
-                # hidden_size / num_attention_heads is wrong in `google/gemma-2-9b-it`
-                head_size=config_dict["head_dim"],
             )
         elif sharded:
             raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Gemma2"))
diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 8028dbe8..eeb3c45f 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -117,7 +117,10 @@ class MistralAttention(torch.nn.Module):
         )
         self.num_heads = config.num_attention_heads
         self.hidden_size = config.hidden_size
-        self.head_size = self.hidden_size // self.num_heads
+        if hasattr(config, "head_dim"):
+            self.head_size = config.head_dim
+        else:
+            self.head_size = self.hidden_size // self.num_heads
 
         self.rotary_emb = PositionRotaryEmbedding.static(
             config=config,
diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py
index cfffafa1..5db62431 100644
--- a/server/text_generation_server/models/flash_causal_lm.py
+++ b/server/text_generation_server/models/flash_causal_lm.py
@@ -925,7 +925,12 @@ class FlashCausalLM(Model):
         assert self.num_kv_heads > 0
 
         if head_size is None:
-            self.head_size = config.hidden_size // config.num_attention_heads
+            # Some models use GQA and different sizes for o_proj
+            # and q_proj, that allows for that.
+            if hasattr(config, "head_dim"):
+                self.head_size = config.head_dim
+            else:
+                self.head_size = config.hidden_size // config.num_attention_heads
         else:
             self.head_size = head_size
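
Note on the change above: deriving the head size as hidden_size // num_attention_heads only holds for models whose heads exactly partition the hidden dimension. As the removed comment notes for google/gemma-2-9b-it, the published config (hidden_size 3584, 16 attention heads, head_dim 256) makes the division yield 224 instead of 256, which breaks the projection shapes, so an explicit head_dim has to take precedence when the config provides one. The sketch below illustrates that resolution order in isolation; it is not TGI code, and resolve_head_size plus the SimpleNamespace stand-in configs are hypothetical names used only for this example.

from types import SimpleNamespace


def resolve_head_size(config) -> int:
    # Prefer an explicit head_dim; only derive it when the config lacks one.
    if hasattr(config, "head_dim"):
        return config.head_dim
    return config.hidden_size // config.num_attention_heads


# Assumed config values matching google/gemma-2-9b-it: 3584 // 16 = 224,
# but the real per-head width is 256, so the derived value would be wrong.
gemma2_like = SimpleNamespace(hidden_size=3584, num_attention_heads=16, head_dim=256)
assert resolve_head_size(gemma2_like) == 256

# A config without head_dim falls back to the old derivation (4096 // 32 = 128).
legacy_like = SimpleNamespace(hidden_size=4096, num_attention_heads=32)
assert resolve_head_size(legacy_like) == 128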