float16 dep
This commit is contained in:
parent
b2cd1b66ed
commit
7cb49f6f4f
|
@ -303,6 +303,7 @@ class MistralMLP(nn.Module):
|
||||||
if (
|
if (
|
||||||
SYSTEM == "rocm"
|
SYSTEM == "rocm"
|
||||||
and self.hidden_act == "silu"
|
and self.hidden_act == "silu"
|
||||||
|
and hidden_states.dtype == torch.float16
|
||||||
and hidden_states.shape[0] == 1
|
and hidden_states.shape[0] == 1
|
||||||
and not self.quantize
|
and not self.quantize
|
||||||
):
|
):
|
||||||
|
|
Loading…
Reference in New Issue