[MPS] call contiguous after permute (#1411)

* call contiguous after permute

Fixes for MPS device

* Fix MPS UserWarning

* make style

* Revert "Fix MPS UserWarning"

This reverts commit b46c32810ee5fdc4c16a8e9224a826490b66cf49.
This commit is contained in:
Kashif Rasul 2022-11-25 13:59:56 +01:00 committed by GitHub
parent 35099b207e
commit babfb8a020
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 6 additions and 2 deletions

View File

@ -221,11 +221,15 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
# 3. Output
if self.is_input_continuous:
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2)
hidden_states = (
hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
)
output = hidden_states + residual
elif self.is_input_vectorized: