float16 dep

This commit is contained in:
Mohit Sharma 2024-09-27 15:53:44 +00:00
parent b2cd1b66ed
commit 7cb49f6f4f
1 changed files with 1 additions and 0 deletions

View File

@ -303,6 +303,7 @@ class MistralMLP(nn.Module):
if (
SYSTEM == "rocm"
and self.hidden_act == "silu"
and hidden_states.dtype == torch.float16
and hidden_states.shape[0] == 1
and not self.quantize
):