mps changes for PyTorch 1.13 (#926)

* Docs: refer to pre-RC version of PyTorch 1.13.0.

* Remove temporary workaround for unavailable op.

* Update comment to make it less ambiguous.

* Remove use of contiguous in mps.

It appears to no longer be necessary.

* Special case: use einsum for much better performance in mps

* Update mps docs.

* Minor doc update.

* Accept suggestion

Co-authored-by: Anton Lozhkov <anton@huggingface.co>

Pedro Cuenca 2022-10-25 16:41:51 +02:00 committed by GitHub
parent 28b134e627
commit 3d02c92187
6 changed files with 48 additions and 26 deletions


@@ -17,9 +17,13 @@ specific language governing permissions and limitations under the License.
## Requirements
- Mac computer with Apple silicon (M1/M2) hardware.
- macOS 12.3 or later.
- macOS 12.6 or later (13.0 or later recommended).
- arm64 version of Python.
- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
```
pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
```
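A quick sanity check (not part of the original docs) to confirm that the install can see the `mps` backend:

```python
import torch

# Both should print True on Apple silicon with a compatible PyTorch build
print(torch.backends.mps.is_built())
print(torch.backends.mps.is_available())
```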
## Inference Pipeline
@@ -34,6 +38,9 @@ from diffusers import StableDiffusionPipeline
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
pipe = pipe.to("mps")
# Recommended if your computer has < 64 GB of RAM
pipe.enable_attention_slicing()
prompt = "a photo of an astronaut riding a horse on mars"
# First-time "warmup" pass (see explanation above)
@@ -43,16 +50,17 @@ _ = pipe(prompt, num_inference_steps=1)
image = pipe(prompt).images[0]
```
## Performance Recommendations
M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does.
We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% on computers without universal memory, but we have observed _better performance_ on most Apple Silicon computers, unless you have 64 GB of RAM or more.
```python
pipeline.enable_attention_slicing()
```
## Known Issues
- As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend iterating instead of batching.
## Performance
These are the results we got on an M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5.
| Device | Steps | Time |
|--------|-------|---------|
| CPU | 50 | 213.46s |
| MPS | 50 | 30.81s |
- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating instead of batching.
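To illustrate that last recommendation, a minimal sketch (the prompt list is illustrative) that iterates over prompts with the `pipe` object from the example above instead of passing them as a single batch:

```python
prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a lighthouse at dawn",
]

# Generate one image per call instead of batching the prompts together
images = [pipe(prompt).images[0] for prompt in prompts]
```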


@@ -249,7 +249,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)
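For context, the pattern behind the updated comment, sketched in isolation (the shape and seed are illustrative): latents are sampled on the CPU with a seeded generator and then moved to the `mps` device, which keeps results reproducible across runs.

```python
import torch

latents_shape = (1, 4, 64, 64)  # illustrative latent shape
generator = torch.Generator(device="cpu").manual_seed(0)

# Sample on CPU for reproducibility, then move the tensor to the mps device
latents = torch.randn(latents_shape, generator=generator, device="cpu").to("mps")
```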


@@ -324,7 +324,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)


@@ -207,7 +207,6 @@ class BasicTransformerBlock(nn.Module):
self.attn2._slice_size = slice_size
def forward(self, hidden_states, context=None):
hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -288,10 +287,19 @@ class CrossAttention(nn.Module):
def _attention(self, query, key, value):
# TODO: use baddbmm for better performance
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
else:
attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
attention_probs = attention_scores.softmax(dim=-1)
# compute attention output
hidden_states = torch.matmul(attention_probs, value)
if query.device.type == "mps":
hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
else:
hidden_states = torch.matmul(attention_probs, value)
# reshape hidden_states
hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
return hidden_states
@@ -305,11 +313,21 @@ class CrossAttention(nn.Module):
for i in range(hidden_states.shape[0] // slice_size):
start_idx = i * slice_size
end_idx = (i + 1) * slice_size
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
if query.device.type == "mps":
# Better performance on mps (~20-25%)
attn_slice = (
torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
* self.scale
)
else:
attn_slice = (
torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
) # TODO: use baddbmm for better performance
attn_slice = attn_slice.softmax(dim=-1)
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
if query.device.type == "mps":
attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
else:
attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
hidden_states[start_idx:end_idx] = attn_slice
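For reference, the einsum expressions introduced above compute the same quantities as the matmul formulations they replace; only the explicit transpose disappears. A small equivalence check with illustrative tensor sizes:

```python
import torch

q = torch.randn(2, 8, 16)  # (batch * heads, query tokens, head dim)
k = torch.randn(2, 8, 16)
v = torch.randn(2, 8, 16)

# Attention scores: einsum over the shared head dimension matches q @ k^T
scores_matmul = torch.matmul(q, k.transpose(-1, -2))
scores_einsum = torch.einsum("b i d, b j d -> b i j", q, k)
assert torch.allclose(scores_matmul, scores_einsum, atol=1e-6)

# Attention output: einsum over the key dimension matches probs @ v
probs = scores_einsum.softmax(dim=-1)
out_matmul = torch.matmul(probs, v)
out_einsum = torch.einsum("b i j, b j d -> b i d", probs, v)
assert torch.allclose(out_matmul, out_einsum, atol=1e-6)
```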


@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
kernel_h, kernel_w = kernel.shape
out = tensor.view(-1, in_h, 1, in_w, 1, minor)
# Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
if tensor.device.type == "mps":
out = out.to("cpu")
out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
out = out.view(-1, in_h * up_y, in_w * up_x, minor)


@@ -287,7 +287,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
latents_dtype = text_embeddings.dtype
if latents is None:
if self.device.type == "mps":
# randn does not exist on mps
# randn does not work reproducibly on mps
latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
self.device
)