[MPS] call contiguous after permute (#1411)
* call contiguous after permute Fixes for MPS device * Fix MPS UserWarning * make style * Revert "Fix MPS UserWarning" This reverts commit b46c32810ee5fdc4c16a8e9224a826490b66cf49.
This commit is contained in:
parent
35099b207e
commit
babfb8a020
|
@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
||||||
# 3. Output
|
# 3. Output
|
||||||
if self.is_input_continuous:
|
if self.is_input_continuous:
|
||||||
if not self.use_linear_projection:
|
if not self.use_linear_projection:
|
||||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
hidden_states = (
|
||||||
|
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||||
|
)
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
else:
|
else:
|
||||||
hidden_states = self.proj_out(hidden_states)
|
hidden_states = self.proj_out(hidden_states)
|
||||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
hidden_states = (
|
||||||
|
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||||
|
)
|
||||||
|
|
||||||
output = hidden_states + residual
|
output = hidden_states + residual
|
||||||
elif self.is_input_vectorized:
|
elif self.is_input_vectorized:
|
||||||
|
|
Loading…
Reference in New Issue