[MPS] call contiguous after permute (#1411)
* call contiguous after permute Fixes for MPS device * Fix MPS UserWarning * make style * Revert "Fix MPS UserWarning" This reverts commit b46c32810ee5fdc4c16a8e9224a826490b66cf49.
This commit is contained in:
parent
35099b207e
commit
babfb8a020
|
@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
|
|||
# 3. Output
|
||||
if self.is_input_continuous:
|
||||
if not self.use_linear_projection:
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
else:
|
||||
hidden_states = self.proj_out(hidden_states)
|
||||
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
|
||||
hidden_states = (
|
||||
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
|
||||
)
|
||||
|
||||
output = hidden_states + residual
|
||||
elif self.is_input_vectorized:
|
||||
|
|
Loading…
Reference in New Issue