xFormers attention op arg (#2049)

* allow passing op to xFormers attention

original code by @patil-suraj
huggingface/diffusers@ae0cc0b71f

* correct style by `make style`

* add attention_op arg documentation

* add usage example to docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* add usage example to docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* code style correction by `make style`

* Update docstring code to a valid python example

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* Update docstring code to a valid python example

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style correction by `make style`

* Update code example to be fully functional

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Takuma Mori 2023-01-25 01:26:04 +09:00 committed by GitHub
parent 7533e3d7e6
commit 16bb5058b9
4 changed files with 73 additions and 16 deletions
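In short, the new `attention_op` argument lets callers force a specific xFormers operator instead of letting xFormers dispatch one automatically. A minimal usage sketch (mirroring the docstring examples added below; assumes xFormers and a CUDA device are available):

```py
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

# Force the Flash Attention operator instead of letting xFormers pick one automatically.
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# The VAE's attention shapes are not accepted by Flash Attention, so keep the default op there.
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```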

Changed file 1 of 4 (AttentionBlock)

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
@@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, 1)
         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None
     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape
@@ -87,7 +88,9 @@
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if not is_xformers_available():
                 raise ModuleNotFoundError(
@@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
             except Exception as e:
                 raise e
         self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self._attention_op = attention_op
     def forward(self, hidden_states):
         residual = hidden_states
@@ -136,7 +140,9 @@
         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+            )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
             attention_scores = torch.baddbmm(

Changed file 2 of 4 (CrossAttention)

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Callable, Optional, Union
 import torch
 import torch.nn.functional as F
@@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
         processor = processor if processor is not None else CrossAttnProcessor()
         self.set_processor(processor)
-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -127,7 +129,7 @@ class CrossAttention(nn.Module):
             except Exception as e:
                 raise e
-            processor = XFormersCrossAttnProcessor()
+            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
             processor = CrossAttnProcessor()
@@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:
 class XFormersCrossAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape
@@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()
-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
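For completeness, the new constructor argument can also be used when wiring the processor up by hand via `set_processor`. A hedged sketch only — the import path, the `CrossAttention` constructor arguments, and the tensor shapes below are illustrative assumptions, not part of this diff:

```py
import torch
from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# Build a standalone cross-attention block (constructor args are assumed, not from this diff).
attn = CrossAttention(query_dim=320, heads=8, dim_head=40).to("cuda", torch.float16)

# Attach the xFormers processor with an explicit operator, as enabled by this change.
attn.set_processor(XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp))

# Self-attention over a dummy sequence; runs memory_efficient_attention with the chosen op.
hidden_states = torch.randn(2, 64, 320, device="cuda", dtype=torch.float16)
out = attn(hidden_states)
```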

Changed file 3 of 4 (ModelMixin)

@@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
             for child in module.children():
                 fn_recursive_set_mem_eff(child)
@@ -205,7 +207,7 @@ class ModelMixin(torch.nn.Module):
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -214,8 +216,28 @@
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import UNet2DConditionModel
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> model = UNet2DConditionModel.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+        ... )
+        >>> model = model.to("cuda")
+        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
     def disable_xformers_memory_efficient_attention(self):
         r"""

Changed file 4 of 4 (DiffusionPipeline)

@@ -19,7 +19,7 @@ import inspect
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union
 import numpy as np
 import torch
@@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs
-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.
@@ -851,8 +851,28 @@ class DiffusionPipeline(ConfigMixin):
         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.
+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
+        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)
     def disable_xformers_memory_efficient_attention(self):
         r"""
@@ -860,13 +880,15 @@
         """
         self.set_use_memory_efficient_attention_xformers(False)
-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
        def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)
             for child in module.children():
                 fn_recursive_set_mem_eff(child)