diff --git a/src/diffusers/models/attention.py b/src/diffusers/models/attention.py
index a1fd9cbf..60f0c8bf 100644
--- a/src/diffusers/models/attention.py
+++ b/src/diffusers/models/attention.py
@@ -288,28 +288,29 @@ class AttentionBlock(nn.Module):
         self._use_memory_efficient_attention_xformers = False
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -450,31 +451,32 @@ class BasicTransformerBlock(nn.Module):
         self.norm3 = nn.LayerNorm(dim)
 
     def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
-        if not is_xformers_available():
-            print("Here is how to install it")
-            raise ModuleNotFoundError(
-                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
-                " xformers",
-                name="xformers",
-            )
-        elif not torch.cuda.is_available():
-            raise ValueError(
-                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
-                " available for GPU "
-            )
-        else:
-            try:
-                # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
+        if use_memory_efficient_attention_xformers:
+            if not is_xformers_available():
+                print("Here is how to install it")
+                raise ModuleNotFoundError(
+                    "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                    " xformers",
+                    name="xformers",
                 )
-            except Exception as e:
-                raise e
-        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
-        if self.attn2 is not None:
-            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+            elif not torch.cuda.is_available():
+                raise ValueError(
+                    "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                    " available for GPU "
+                )
+            else:
+                try:
+                    # Make sure we can run the memory efficient attention
+                    _ = xformers.ops.memory_efficient_attention(
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                        torch.randn((1, 2, 40), device="cuda"),
+                    )
+                except Exception as e:
+                    raise e
+        self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        if self.attn2 is not None:
+            self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
 
     def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
         # 1. Self-Attention
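
For context, a minimal usage sketch of the behavior this patch enables (illustration only, not part of the diff): the xformers/CUDA availability checks now run only when the flag is being turned on, so disabling memory-efficient attention works on machines without xformers or a GPU. The AttentionBlock constructor argument below is an assumed placeholder.

# Illustrative sketch, not part of the patch above; channels=32 is an
# assumed minimal construction.
from diffusers.models.attention import AttentionBlock

block = AttentionBlock(channels=32)

# With this change, disabling no longer raises ModuleNotFoundError or
# ValueError when xformers or CUDA is unavailable; the checks guard only
# the enabling path.
block.set_use_memory_efficient_attention_xformers(False)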