Add xformers attention to VAE (#1507)
* Add xformers attention to VAE
* Simplify VAE xformers code
* Update src/diffusers/models/attention.py

Co-authored-by: Ilmari Heikkinen <ilmari@fhtr.org>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
parent ae368e42d2
commit daebee0963
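For context (not part of the commit), a minimal usage sketch: it assumes the pipeline helper enable_xformers_memory_efficient_attention (available in diffusers around this release), which forwards set_use_memory_efficient_attention_xformers to submodules, and it assumes xformers is installed and a CUDA device is present; the model id is only an example.

    import torch
    from diffusers import StableDiffusionPipeline

    # Example checkpoint; any Stable Diffusion model with a KL VAE behaves the same way.
    pipe = StableDiffusionPipeline.from_pretrained(
        "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
    ).to("cuda")

    # Forwards set_use_memory_efficient_attention_xformers(True) to submodules,
    # which after this commit includes the VAE's AttentionBlock.
    pipe.enable_xformers_memory_efficient_attention()

    image = pipe("a photo of an astronaut riding a horse").images[0]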
@@ -286,6 +286,32 @@ class AttentionBlock(nn.Module):
         self.rescale_output_factor = rescale_output_factor
         self.proj_attn = nn.Linear(channels, channels, 1)
 
+        self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+        elif not torch.cuda.is_available():
+            raise ValueError(
+                "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+                " available for GPU "
+            )
+        else:
+            try:
+                # Make sure we can run the memory efficient attention
+                _ = xformers.ops.memory_efficient_attention(
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                    torch.randn((1, 2, 40), device="cuda"),
+                )
+            except Exception as e:
+                raise e
+        self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
         head_size = self.num_heads
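A second sketch (also not part of the diff) toggling the new flag directly on a standalone AttentionBlock. It assumes the constructor's first argument is channels (consistent with the nn.Linear(channels, channels, 1) context above) and that forward takes a (batch, channels, height, width) feature map; it requires xformers plus a CUDA device, otherwise the setter raises as shown in the hunk above.

    import torch
    from diffusers.models.attention import AttentionBlock

    block = AttentionBlock(channels=512).to("cuda", dtype=torch.float16)  # 512 channels, as in the VAE mid-block
    block.set_use_memory_efficient_attention_xformers(True)              # raises if xformers or CUDA is missing

    sample = torch.randn(1, 512, 32, 32, device="cuda", dtype=torch.float16)
    out = block(sample)   # forward now routes through xformers.ops.memory_efficient_attention
    print(out.shape)      # spatial shape is preserved: torch.Size([1, 512, 32, 32])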
@@ -320,21 +346,26 @@ class AttentionBlock(nn.Module):
         key_proj = self.reshape_heads_to_batch_dim(key_proj)
         value_proj = self.reshape_heads_to_batch_dim(value_proj)
 
-        attention_scores = torch.baddbmm(
-            torch.empty(
-                query_proj.shape[0],
-                query_proj.shape[1],
-                key_proj.shape[1],
-                dtype=query_proj.dtype,
-                device=query_proj.device,
-            ),
-            query_proj,
-            key_proj.transpose(-1, -2),
-            beta=0,
-            alpha=scale,
-        )
-        attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
-        hidden_states = torch.bmm(attention_probs, value_proj)
+        if self._use_memory_efficient_attention_xformers:
+            # Memory efficient attention
+            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = hidden_states.to(query_proj.dtype)
+        else:
+            attention_scores = torch.baddbmm(
+                torch.empty(
+                    query_proj.shape[0],
+                    query_proj.shape[1],
+                    key_proj.shape[1],
+                    dtype=query_proj.dtype,
+                    device=query_proj.device,
+                ),
+                query_proj,
+                key_proj.transpose(-1, -2),
+                beta=0,
+                alpha=scale,
+            )
+            attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+            hidden_states = torch.bmm(attention_probs, value_proj)
 
         # reshape hidden_states
         hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
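Finally, a hedged equivalence check (not part of the commit) showing why the two branches are interchangeable: xformers' memory_efficient_attention defaults to the same 1/sqrt(head_dim) scaling that the baddbmm path applies via alpha=scale. It assumes xformers and a CUDA device, mirroring the guards in the setter above.

    import math
    import torch
    import xformers.ops

    batch_heads, seq_len, head_dim = 2, 64, 40       # (batch * num_heads, tokens, channels per head)
    scale = 1 / math.sqrt(head_dim)                  # same scaling the AttentionBlock uses

    q = torch.randn(batch_heads, seq_len, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Standard path: scores = scale * q @ k^T via baddbmm (beta=0 ignores the empty input),
    # then softmax in float32 and a bmm with the values.
    scores = torch.baddbmm(
        torch.empty(batch_heads, seq_len, seq_len, dtype=q.dtype, device=q.device),
        q, k.transpose(-1, -2), beta=0, alpha=scale,
    )
    probs = torch.softmax(scores.float(), dim=-1).type(scores.dtype)
    reference = torch.bmm(probs, v)

    # Memory-efficient path: xformers applies 1/sqrt(head_dim) by default.
    efficient = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None)

    print((reference - efficient).abs().max())       # expected to be small (fp16 rounding)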