diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
index 3e16d371..341a2352 100644
--- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
+++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py
@@ -303,6 +303,7 @@ class MistralMLP(nn.Module):
         if (
             SYSTEM == "rocm"
             and self.hidden_act == "silu"
+            and hidden_states.dtype == torch.float16
             and hidden_states.shape[0] == 1
             and not self.quantize
         ):