Add xformers attention to VAE (#1507)

* Add xformers attention to VAE

* Simplify VAE xformers code

* Update src/diffusers/models/attention.py

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Ilmari Heikkinen 2022-12-03 22:08:11 +08:00 committed by GitHub
parent ae368e42d2
commit daebee0963
1 changed file with 46 additions and 15 deletions

src/diffusers/models/attention.py

@@ -286,6 +286,32 @@ class AttentionBlock(nn.Module):
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
+        self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
@@ -320,21 +346,26 @@ class AttentionBlock(nn.Module):
         key_proj = self.reshape_heads_to_batch_dim(key_proj)
         value_proj = self.reshape_heads_to_batch_dim(value_proj)
 
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query_proj.shape[0],
-                query_proj.shape[1],
-                key_proj.shape[1],
-                dtype=query_proj.dtype,
-                device=query_proj.device,
-            ),
-            query_proj,
-            key_proj.transpose(-1, -2),
-            beta=0,
-            alpha=scale,
-        )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-        hidden_states = torch.bmm(attention_probs, value_proj)
+        if self._use_memory_efficient_attention_xformers:
+            # Memory efficient attention
+            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = hidden_states.to(query_proj.dtype)
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_proj.shape[0],
+                    query_proj.shape[1],
+                    key_proj.shape[1],
+                    dtype=query_proj.dtype,
+                    device=query_proj.device,
+                ),
+                query_proj,
+                key_proj.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
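
For reference, a minimal usage sketch of the new toggle (not part of the commit). It assumes xformers is installed and a CUDA device is available; the channel count and input size are illustrative values, roughly matching the Stable Diffusion VAE mid-block attention.

    import torch
    from diffusers.models.attention import AttentionBlock

    # Sketch only: channels=512 mirrors the SD VAE mid-block; any supported width works.
    attn = AttentionBlock(channels=512).to("cuda", dtype=torch.float16)
    attn.set_use_memory_efficient_attention_xformers(True)  # raises if xformers or CUDA is missing

    # AttentionBlock consumes NCHW feature maps, e.g. a 32x32 latent feature map.
    sample = torch.randn(1, 512, 32, 32, device="cuda", dtype=torch.float16)
    with torch.no_grad():
        out = attn(sample)  # uses xformers.ops.memory_efficient_attention internally

    print(out.shape)  # torch.Size([1, 512, 32, 32])

Presumably, the pipeline-level enable_xformers_memory_efficient_attention() call, which recursively toggles any submodule exposing set_use_memory_efficient_attention_xformers, now reaches the VAE's attention blocks as well. The xformers path should match the baddbmm/softmax fallback up to numerical precision, since memory_efficient_attention's default scaling of 1/sqrt(head_dim) equals the scale used by this block.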