xFormers attention op arg (#2049)

* allow passing op to xFormers attention

original code by @patil-suraj
huggingface/diffusers@ae0cc0b71f

* correct style by `make style`

* add attention_op arg documentation

* add usage example to docstring

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>

* code style correction by `make style`

* Update docstring code to a valid python example

Co-authored-by: Suraj Patil <surajp815@gmail.com>

* style correction by `make style`

* Update code example to be fully functional

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
Co-authored-by: Suraj Patil <surajp815@gmail.com>
Takuma Mori, 2023-01-25 01:26:04 +09:00, committed by GitHub
parent 7533e3d7e6
commit 16bb5058b9
4 changed files with 73 additions and 16 deletions
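
The change threads a single new argument, `attention_op`, from the public `enable_xformers_memory_efficient_attention()` entry points down to the `op=` parameter of `xformers.ops.memory_efficient_attention`. As a quick orientation before the diffs, here is a minimal sketch (not part of the commit) of what that `op=` parameter selects, assuming an xFormers build with FlashAttention support and a CUDA fp16 setup:

```py
# Illustrative only: what the forwarded `op` argument selects inside xFormers.
import torch
import xformers.ops
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

q = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)
k = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)
v = torch.randn(2, 1024, 64, device="cuda", dtype=torch.float16)

# op=None keeps xFormers' automatic kernel dispatch (the behavior before this commit).
out_default = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

# Passing an explicit operator pins a particular kernel, e.g. FlashAttention.
out_flash = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None, op=MemoryEfficientAttentionFlashAttentionOp
)
```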

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Optional
from typing import Callable, Optional
import torch
import torch.nn.functional as F
@@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
self.proj_attn = nn.Linear(channels, channels, 1)
self._use_memory_efficient_attention_xformers = False
self._attention_op = None
def reshape_heads_to_batch_dim(self, tensor):
batch_size, seq_len, dim = tensor.shape
@@ -87,7 +88,9 @@ class AttentionBlock(nn.Module):
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
return tensor
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if not is_xformers_available():
raise ModuleNotFoundError(
@@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
except Exception as e:
raise e
self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
self._attention_op = attention_op
def forward(self, hidden_states):
residual = hidden_states
@@ -136,7 +140,9 @@ class AttentionBlock(nn.Module):
if self._use_memory_efficient_attention_xformers:
# Memory efficient attention
hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
hidden_states = xformers.ops.memory_efficient_attention(
query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
)
hidden_states = hidden_states.to(query_proj.dtype)
else:
attention_scores = torch.baddbmm(

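In short, `AttentionBlock` now remembers the operator handed to the setter and forwards it unchanged at attention time, with `None` preserving the old automatic dispatch. Below is a standalone miniature of that pattern; `TinyAttention` and its layout are made up for illustration and are not diffusers classes:

```py
# Illustrative miniature of the pattern added to AttentionBlock above.
from typing import Callable, Optional

import torch
import xformers.ops


class TinyAttention(torch.nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.to_qkv = torch.nn.Linear(dim, dim * 3)
        self._use_xformers = False
        self._attention_op: Optional[Callable] = None  # None -> xFormers picks the kernel

    def set_use_memory_efficient_attention_xformers(
        self, use_xformers: bool, attention_op: Optional[Callable] = None
    ):
        self._use_xformers = use_xformers
        self._attention_op = attention_op

    def forward(self, x):  # x: (batch, seq_len, dim)
        q, k, v = self.to_qkv(x).chunk(3, dim=-1)
        if self._use_xformers:
            # Same call shape as the hunk above: attn_bias=None, op from the setter.
            return xformers.ops.memory_efficient_attention(
                q.contiguous(), k.contiguous(), v.contiguous(),
                attn_bias=None, op=self._attention_op,
            )
        # Plain attention fallback when xFormers is not enabled.
        scores = torch.softmax(q @ k.transpose(-2, -1) / q.shape[-1] ** 0.5, dim=-1)
        return scores @ v
```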
View File

@@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional, Union
from typing import Callable, Optional, Union
import torch
import torch.nn.functional as F
@@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
processor = processor if processor is not None else CrossAttnProcessor()
self.set_processor(processor)
def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
def set_use_memory_efficient_attention_xformers(
self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
):
if use_memory_efficient_attention_xformers:
if self.added_kv_proj_dim is not None:
# TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP
@@ -127,7 +129,7 @@ class CrossAttention(nn.Module):
except Exception as e:
raise e
processor = XFormersCrossAttnProcessor()
processor = XFormersCrossAttnProcessor(attention_op=attention_op)
else:
processor = CrossAttnProcessor()
@@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:
class XFormersCrossAttnProcessor:
def __init__(self, attention_op: Optional[Callable] = None):
self.attention_op = attention_op
def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
batch_size, sequence_length, _ = hidden_states.shape
@@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
key = attn.head_to_batch_dim(key).contiguous()
value = attn.head_to_batch_dim(value).contiguous()
hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
hidden_states = xformers.ops.memory_efficient_attention(
query, key, value, attn_bias=attention_mask, op=self.attention_op
)
hidden_states = hidden_states.to(query.dtype)
hidden_states = attn.batch_to_head_dim(hidden_states)

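The processor can also be installed by hand. The sketch below relies only on what the hunks above show (`XFormersCrossAttnProcessor` and `CrossAttention.set_processor` living in `diffusers.models.cross_attention`); in normal use, `enable_xformers_memory_efficient_attention(attention_op=...)` on the model does the same thing for you, and the explicit loop is only for illustration:

```py
# Hedged sketch: installing the processor from this hunk on every CrossAttention
# module of a UNet, equivalent in effect to the public enable_* entry point.
import torch
from diffusers import UNet2DConditionModel
from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
).to("cuda")

processor = XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp)
for module in unet.modules():
    if isinstance(module, CrossAttention):
        module.set_processor(processor)
```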
View File

@@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
for child in module.children():
fn_recursive_set_mem_eff(child)
@@ -205,7 +207,7 @@ class ModelMixin(torch.nn.Module):
if isinstance(module, torch.nn.Module):
fn_recursive_set_mem_eff(module)
def enable_xformers_memory_efficient_attention(self):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -214,8 +216,28 @@ class ModelMixin(torch.nn.Module):
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.
Examples:
```py
>>> import torch
>>> from diffusers import UNet2DConditionModel
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> model = UNet2DConditionModel.from_pretrained(
... "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
... )
>>> model = model.to("cuda")
>>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
```
"""
self.set_use_memory_efficient_attention_xformers(True)
self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self):
r"""

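Read together, the signature change and the recursive walk above amount to the following standalone helper (the name `propagate_xformers_setting` is invented for this sketch): every descendant module that exposes the setter receives both the flag and the operator.

```py
# Illustrative, standalone rendering of the recursive dispatch in the hunk above.
from typing import Callable, Optional

import torch


def propagate_xformers_setting(
    root: torch.nn.Module, valid: bool, attention_op: Optional[Callable] = None
) -> None:
    def fn_recursive_set_mem_eff(module: torch.nn.Module):
        if hasattr(module, "set_use_memory_efficient_attention_xformers"):
            module.set_use_memory_efficient_attention_xformers(valid, attention_op)
        for child in module.children():
            fn_recursive_set_mem_eff(child)

    for module in root.children():
        if isinstance(module, torch.nn.Module):
            fn_recursive_set_mem_eff(module)
```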
View File

@@ -19,7 +19,7 @@ import inspect
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
@@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
def set_progress_bar_config(self, **kwargs):
self._progress_bar_config = kwargs
def enable_xformers_memory_efficient_attention(self):
def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
r"""
Enable memory efficient attention as implemented in xformers.
@@ -851,8 +851,28 @@ class DiffusionPipeline(ConfigMixin):
Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
is used.
Parameters:
attention_op (`Callable`, *optional*):
Override the default `None` operator for use as `op` argument to the
[`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
function of xFormers.
Examples:
```py
>>> import torch
>>> from diffusers import DiffusionPipeline
>>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
>>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
>>> pipe = pipe.to("cuda")
>>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
>>> # Workaround for not accepting attention shape using VAE for Flash Attention
>>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```
"""
self.set_use_memory_efficient_attention_xformers(True)
self.set_use_memory_efficient_attention_xformers(True, attention_op)
def disable_xformers_memory_efficient_attention(self):
r"""
@@ -860,13 +880,15 @@
"""
self.set_use_memory_efficient_attention_xformers(False)
def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
def set_use_memory_efficient_attention_xformers(
self, valid: bool, attention_op: Optional[Callable] = None
) -> None:
# Recursively walk through all the children.
# Any children which exposes the set_use_memory_efficient_attention_xformers method
# gets the message
def fn_recursive_set_mem_eff(module: torch.nn.Module):
if hasattr(module, "set_use_memory_efficient_attention_xformers"):
module.set_use_memory_efficient_attention_xformers(valid)
module.set_use_memory_efficient_attention_xformers(valid, attention_op)
for child in module.children():
fn_recursive_set_mem_eff(child)
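
Putting the pipeline docstring example above into an end-to-end form (a sketch only: it assumes a CUDA GPU and an xFormers build with FlashAttention support, and the prompt is arbitrary):

```py
# End-to-end sketch based on the docstring added in this commit.
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# Pin FlashAttention for the pipeline's attention layers.
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
# FlashAttention does not accept the VAE's attention shapes, so keep the VAE on
# xFormers' default kernel dispatch (op=None), as noted in the docstring.
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)

image = pipe("an astronaut riding a horse on the moon").images[0]
image.save("astronaut.png")
```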