diff --git a/docs/source/optimization/mps.mdx b/docs/source/optimization/mps.mdx
index 754adae5..4eeabc65 100644
--- a/docs/source/optimization/mps.mdx
+++ b/docs/source/optimization/mps.mdx
@@ -17,9 +17,13 @@ specific language governing permissions and limitations under the License.
 
 ## Requirements
 
 - Mac computer with Apple silicon (M1/M2) hardware.
-- macOS 12.3 or later.
+- macOS 12.6 or later (13.0 or later recommended).
 - arm64 version of Python.
-- PyTorch [Preview (Nightly)](https://pytorch.org/get-started/locally/), version `1.14.0.dev20221007` or later.
+- PyTorch 1.13.0 RC (Release Candidate). You can install it with `pip` using:
+
+```
+pip3 install --pre torch --extra-index-url https://download.pytorch.org/whl/test/cpu
+```
 
 ## Inference Pipeline
@@ -34,6 +38,9 @@ from diffusers import StableDiffusionPipeline
 pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
 pipe = pipe.to("mps")
 
+# Recommended if your computer has < 64 GB of RAM
+pipe.enable_attention_slicing()
+
 prompt = "a photo of an astronaut riding a horse on mars"
 
 # First-time "warmup" pass (see explanation above)
@@ -43,16 +50,17 @@ _ = pipe(prompt, num_inference_steps=1)
 image = pipe(prompt).images[0]
 ```
 
+## Performance Recommendations
+
+M1/M2 performance is very sensitive to memory pressure. The system will automatically swap if it needs to, but performance will degrade significantly when it does.
+
+We recommend you use _attention slicing_ to reduce memory pressure during inference and prevent swapping, particularly if your computer has less than 64 GB of system RAM, or if you generate images at non-standard resolutions larger than 512 × 512 pixels. Attention slicing performs the costly attention operation in multiple steps instead of all at once. It usually has a performance impact of ~20% in computers without universal memory, but we have observed _better performance_ in most Apple Silicon computers, unless you have 64 GB of RAM or more.
+
+```python
+pipe.enable_attention_slicing()
+```
+
 ## Known Issues
 
 - As mentioned above, we are investigating a strange [first-time inference issue](https://github.com/huggingface/diffusers/issues/372).
-- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this might be related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039#issuecomment-1237735249), but we need to investigate in more depth. For now, we recommend to iterate instead of batching.
-
-## Performance
-
-These are the results we got on a M1 Max MacBook Pro with 64 GB of RAM, running macOS Ventura Version 13.0 Beta (22A5331f). We performed Stable Diffusion text-to-image generation of the same prompt for 50 inference steps, using a guidance scale of 7.5.
-
-| Device | Steps | Time |
-|--------|-------|---------|
-| CPU | 50 | 213.46s |
-| MPS | 50 | 30.81s |
\ No newline at end of file
+- Generating multiple prompts in a batch [crashes or doesn't work reliably](https://github.com/huggingface/diffusers/issues/363). We believe this is related to the [`mps` backend in PyTorch](https://github.com/pytorch/pytorch/issues/84039). For now, we recommend iterating instead of batching.
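As a concrete illustration of the "iterate instead of batching" recommendation in the docs above, the following minimal sketch generates several images one prompt at a time on `mps`, with attention slicing enabled and the first-time warmup pass included. The prompts and output file names are made up for the example.

```python
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("mps")
pipe.enable_attention_slicing()  # recommended if your computer has < 64 GB of RAM

prompts = [
    "a photo of an astronaut riding a horse on mars",
    "a watercolor painting of a lighthouse at dawn",
]

# First-time "warmup" pass to work around the first-time inference issue
_ = pipe(prompts[0], num_inference_steps=1)

# Iterate over prompts one at a time instead of passing them as a batch
for i, prompt in enumerate(prompts):
    image = pipe(prompt).images[0]
    image.save(f"image_{i}.png")
```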
diff --git a/examples/community/clip_guided_stable_diffusion.py b/examples/community/clip_guided_stable_diffusion.py
index 974f4ab2..2c86e913 100644
--- a/examples/community/clip_guided_stable_diffusion.py
+++ b/examples/community/clip_guided_stable_diffusion.py
@@ -249,7 +249,7 @@ class CLIPGuidedStableDiffusion(DiffusionPipeline):
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
diff --git a/examples/community/interpolate_stable_diffusion.py b/examples/community/interpolate_stable_diffusion.py
index 97116bdc..bbb1b0f9 100644
--- a/examples/community/interpolate_stable_diffusion.py
+++ b/examples/community/interpolate_stable_diffusion.py
@@ -324,7 +324,7 @@ class StableDiffusionWalkPipeline(DiffusionPipeline):
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index 4906e10f..dce30c6a 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -207,7 +207,6 @@ class BasicTransformerBlock(nn.Module):
         self.attn2._slice_size = slice_size
 
     def forward(self, hidden_states, context=None):
-        hidden_states = hidden_states.contiguous() if hidden_states.device.type == "mps" else hidden_states
         hidden_states = self.attn1(self.norm1(hidden_states)) + hidden_states
         hidden_states = self.attn2(self.norm2(hidden_states), context=context) + hidden_states
         hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
@@ -288,10 +287,19 @@ class CrossAttention(nn.Module):
 
     def _attention(self, query, key, value):
         # TODO: use baddbmm for better performance
-        attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
+        if query.device.type == "mps":
+            # Better performance on mps (~20-25%)
+            attention_scores = torch.einsum("b i d, b j d -> b i j", query, key) * self.scale
+        else:
+            attention_scores = torch.matmul(query, key.transpose(-1, -2)) * self.scale
         attention_probs = attention_scores.softmax(dim=-1)
         # compute attention output
-        hidden_states = torch.matmul(attention_probs, value)
+
+        if query.device.type == "mps":
+            hidden_states = torch.einsum("b i j, b j d -> b i d", attention_probs, value)
+        else:
+            hidden_states = torch.matmul(attention_probs, value)
+
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         return hidden_states
@@ -305,11 +313,21 @@ class CrossAttention(nn.Module):
         for i in range(hidden_states.shape[0] // slice_size):
             start_idx = i * slice_size
             end_idx = (i + 1) * slice_size
-            attn_slice = (
-                torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
-            )  # TODO: use baddbmm for better performance
+            if query.device.type == "mps":
+                # Better performance on mps (~20-25%)
+                attn_slice = (
+                    torch.einsum("b i d, b j d -> b i j", query[start_idx:end_idx], key[start_idx:end_idx])
+                    * self.scale
+                )
+            else:
+                attn_slice = (
+                    torch.matmul(query[start_idx:end_idx], key[start_idx:end_idx].transpose(1, 2)) * self.scale
+                )  # TODO: use baddbmm for better performance
             attn_slice = attn_slice.softmax(dim=-1)
-            attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
+            if query.device.type == "mps":
+                attn_slice = torch.einsum("b i j, b j d -> b i d", attn_slice, value[start_idx:end_idx])
+            else:
+                attn_slice = torch.matmul(attn_slice, value[start_idx:end_idx])
 
             hidden_states[start_idx:end_idx] = attn_slice
diff --git a/src/diffusers/models/resnet.py b/src/diffusers/models/resnet.py
index d4cb367e..fbd78b51 100644
--- a/src/diffusers/models/resnet.py
+++ b/src/diffusers/models/resnet.py
@@ -492,10 +492,6 @@ def upfirdn2d_native(tensor, kernel, up=1, down=1, pad=(0, 0)):
     kernel_h, kernel_w = kernel.shape
 
     out = tensor.view(-1, in_h, 1, in_w, 1, minor)
-
-    # Temporary workaround for mps specific issue: https://github.com/pytorch/pytorch/issues/84535
-    if tensor.device.type == "mps":
-        out = out.to("cpu")
     out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
     out = out.view(-1, in_h * up_y, in_w * up_x, minor)
diff --git a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
index 46def26e..02a6b45f 100644
--- a/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
+++ b/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -287,7 +287,7 @@ class StableDiffusionPipeline(DiffusionPipeline):
         latents_dtype = text_embeddings.dtype
         if latents is None:
             if self.device.type == "mps":
-                # randn does not exist on mps
+                # randn does not work reproducibly on mps
                 latents = torch.randn(latents_shape, generator=generator, device="cpu", dtype=latents_dtype).to(
                     self.device
                 )
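For anyone who wants to sanity-check the einsum rewrite in `CrossAttention`, here is a small standalone sketch (tensor shapes are arbitrary) verifying that the `torch.einsum` formulation used on `mps` computes the same attention output as the `torch.matmul` path it replaces:

```python
import torch

# Arbitrary shapes: (batch * heads, sequence length, head dimension)
query = torch.randn(2, 4, 8)
key = torch.randn(2, 4, 8)
value = torch.randn(2, 4, 8)
scale = query.shape[-1] ** -0.5

# Original path: batched matmul against the transposed key
scores_matmul = torch.matmul(query, key.transpose(-1, -2)) * scale
out_matmul = torch.matmul(scores_matmul.softmax(dim=-1), value)

# New mps path: the same contractions expressed with einsum
scores_einsum = torch.einsum("b i d, b j d -> b i j", query, key) * scale
out_einsum = torch.einsum("b i j, b j d -> b i d", scores_einsum.softmax(dim=-1), value)

# The two formulations agree up to floating-point tolerance
assert torch.allclose(out_matmul, out_einsum, atol=1e-6)
```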