xFormers attention op arg (#2049)
* allow passing op to xFormers attention original code by @patil-suraj huggingface/diffusers@ae0cc0b71f * correct style by `make style` * add attention_op arg documents * add usage example to docstring Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * add usage example to docstring Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * code style correction by `make style` * Update docstring code to a valid python example Co-authored-by: Suraj Patil <surajp815@gmail.com> * Update docstring code to a valid python example Co-authored-by: Suraj Patil <surajp815@gmail.com> * style correction by `make style` * Update code exmaple to fully functional Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Suraj Patil <surajp815@gmail.com>
This commit is contained in:
parent
7533e3d7e6
commit
16bb5058b9
|
@ -12,7 +12,7 @@
|
|||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import math
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
|
|||
self.proj_attn = nn.Linear(channels, channels, 1)
|
||||
|
||||
self._use_memory_efficient_attention_xformers = False
|
||||
self._attention_op = None
|
||||
|
||||
def reshape_heads_to_batch_dim(self, tensor):
|
||||
batch_size, seq_len, dim = tensor.shape
|
||||
|
@ -87,7 +88,9 @@ class AttentionBlock(nn.Module):
|
|||
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
||||
return tensor
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
||||
):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
|
@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
|
|||
except Exception as e:
|
||||
raise e
|
||||
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
|
||||
self._attention_op = attention_op
|
||||
|
||||
def forward(self, hidden_states):
|
||||
residual = hidden_states
|
||||
|
@ -136,7 +140,9 @@ class AttentionBlock(nn.Module):
|
|||
|
||||
if self._use_memory_efficient_attention_xformers:
|
||||
# Memory efficient attention
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
|
||||
)
|
||||
hidden_states = hidden_states.to(query_proj.dtype)
|
||||
else:
|
||||
attention_scores = torch.baddbmm(
|
||||
|
|
|
@ -11,7 +11,7 @@
|
|||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Optional, Union
|
||||
from typing import Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
|
|||
processor = processor if processor is not None else CrossAttnProcessor()
|
||||
self.set_processor(processor)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
|
||||
):
|
||||
if use_memory_efficient_attention_xformers:
|
||||
if self.added_kv_proj_dim is not None:
|
||||
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
|
||||
|
@ -127,7 +129,7 @@ class CrossAttention(nn.Module):
|
|||
except Exception as e:
|
||||
raise e
|
||||
|
||||
processor = XFormersCrossAttnProcessor()
|
||||
processor = XFormersCrossAttnProcessor(attention_op=attention_op)
|
||||
else:
|
||||
processor = CrossAttnProcessor()
|
||||
|
||||
|
@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:
|
|||
|
||||
|
||||
class XFormersCrossAttnProcessor:
|
||||
def __init__(self, attention_op: Optional[Callable] = None):
|
||||
self.attention_op = attention_op
|
||||
|
||||
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
||||
batch_size, sequence_length, _ = hidden_states.shape
|
||||
|
||||
|
@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
|
|||
key = attn.head_to_batch_dim(key).contiguous()
|
||||
value = attn.head_to_batch_dim(value).contiguous()
|
||||
|
||||
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
|
||||
hidden_states = xformers.ops.memory_efficient_attention(
|
||||
query, key, value, attn_bias=attention_mask, op=self.attention_op
|
||||
)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
|
|
|
@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
|
|||
if self._supports_gradient_checkpointing:
|
||||
self.apply(partial(self._set_gradient_checkpointing, value=False))
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, valid: bool, attention_op: Optional[Callable] = None
|
||||
) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
@ -205,7 +207,7 @@ class ModelMixin(torch.nn.Module):
|
|||
if isinstance(module, torch.nn.Module):
|
||||
fn_recursive_set_mem_eff(module)
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
|
@ -214,8 +216,28 @@ class ModelMixin(torch.nn.Module):
|
|||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
|
||||
Parameters:
|
||||
attention_op (`Callable`, *optional*):
|
||||
Override the default `None` operator for use as `op` argument to the
|
||||
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
||||
function of xFormers.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import UNet2DConditionModel
|
||||
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
||||
|
||||
>>> model = UNet2DConditionModel.from_pretrained(
|
||||
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
|
||||
... )
|
||||
>>> model = model.to("cuda")
|
||||
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
||||
```
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(True)
|
||||
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
|
|
|
@ -19,7 +19,7 @@ import inspect
|
|||
import os
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
|
|||
def set_progress_bar_config(self, **kwargs):
|
||||
self._progress_bar_config = kwargs
|
||||
|
||||
def enable_xformers_memory_efficient_attention(self):
|
||||
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
|
||||
r"""
|
||||
Enable memory efficient attention as implemented in xformers.
|
||||
|
||||
|
@ -851,8 +851,28 @@ class DiffusionPipeline(ConfigMixin):
|
|||
|
||||
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
|
||||
is used.
|
||||
|
||||
Parameters:
|
||||
attention_op (`Callable`, *optional*):
|
||||
Override the default `None` operator for use as `op` argument to the
|
||||
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
|
||||
function of xFormers.
|
||||
|
||||
Examples:
|
||||
|
||||
```py
|
||||
>>> import torch
|
||||
>>> from diffusers import DiffusionPipeline
|
||||
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
|
||||
|
||||
>>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
|
||||
>>> pipe = pipe.to("cuda")
|
||||
>>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
|
||||
>>> # Workaround for not accepting attention shape using VAE for Flash Attention
|
||||
>>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
|
||||
```
|
||||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(True)
|
||||
self.set_use_memory_efficient_attention_xformers(True, attention_op)
|
||||
|
||||
def disable_xformers_memory_efficient_attention(self):
|
||||
r"""
|
||||
|
@ -860,13 +880,15 @@ class DiffusionPipeline(ConfigMixin):
|
|||
"""
|
||||
self.set_use_memory_efficient_attention_xformers(False)
|
||||
|
||||
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
|
||||
def set_use_memory_efficient_attention_xformers(
|
||||
self, valid: bool, attention_op: Optional[Callable] = None
|
||||
) -> None:
|
||||
# Recursively walk through all the children.
|
||||
# Any children which exposes the set_use_memory_efficient_attention_xformers method
|
||||
# gets the message
|
||||
def fn_recursive_set_mem_eff(module: torch.nn.Module):
|
||||
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
|
||||
module.set_use_memory_efficient_attention_xformers(valid)
|
||||
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
|
||||
|
||||
for child in module.children():
|
||||
fn_recursive_set_mem_eff(child)
|
||||
|
|
Loading…
Reference in New Issue