xFormers attention op arg (#2049)
* allow passing op to xFormers attention (original code by @patil-suraj, huggingface/diffusers@ae0cc0b71f)
* correct style by `make style`
* add documentation for the attention_op arg
* add usage example to docstring
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* add usage example to docstring
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
* code style correction by `make style`
* update docstring code to a valid Python example
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
* update docstring code to a valid Python example
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
* style correction by `make style`
* update code example to be fully functional
  Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
  Co-authored-by: Suraj Patil <surajp815@gmail.com>
Parent: 7533e3d7e6
Commit: 16bb5058b9
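In short, the PR threads an optional `attention_op` argument from the user-facing `enable_xformers_memory_efficient_attention()` helpers down to `xformers.ops.memory_efficient_attention(..., op=...)`, so a specific xFormers operator can be forced instead of letting xFormers pick one. A minimal usage sketch, adapted from the docstring example added in this diff (assumes xFormers is installed and a CUDA device is available):

```py
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# Force the flash-attention operator for every attention module in the pipeline
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
pipe = pipe.to("cuda")
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# Workaround: the VAE's attention shape is not accepted by flash attention,
# so let xFormers auto-select the operator for it (op=None)
pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
```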
Changes to `AttentionBlock`:

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import math
-from typing import Optional
+from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

@@ -72,6 +72,7 @@ class AttentionBlock(nn.Module):
         self.proj_attn = nn.Linear(channels, channels, 1)

         self._use_memory_efficient_attention_xformers = False
+        self._attention_op = None

     def reshape_heads_to_batch_dim(self, tensor):
         batch_size, seq_len, dim = tensor.shape

@@ -87,7 +88,9 @@ class AttentionBlock(nn.Module):
         tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         return tensor

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if not is_xformers_available():
                 raise ModuleNotFoundError(

@@ -113,6 +116,7 @@ class AttentionBlock(nn.Module):
                 except Exception as e:
                     raise e
         self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+        self._attention_op = attention_op

     def forward(self, hidden_states):
         residual = hidden_states

@@ -136,7 +140,9 @@ class AttentionBlock(nn.Module):

         if self._use_memory_efficient_attention_xformers:
             # Memory efficient attention
-            hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+            hidden_states = xformers.ops.memory_efficient_attention(
+                query_proj, key_proj, value_proj, attn_bias=None, op=self._attention_op
+            )
             hidden_states = hidden_states.to(query_proj.dtype)
         else:
             attention_scores = torch.baddbmm(
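`AttentionBlock` now stores the operator on `self._attention_op` and forwards it unchanged as the `op` keyword of `xformers.ops.memory_efficient_attention`. A standalone sketch of that call, using toy tensors in the `(batch * heads, seq_len, head_dim)` layout the block produces (assumes xFormers and a CUDA device; the flash-attention op additionally requires fp16/bf16 inputs and has its own head-size/hardware constraints):

```py
import torch
import xformers.ops
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# Toy query/key/value tensors in the (batch * heads, seq_len, head_dim) layout
q = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)
k = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)
v = torch.randn(2, 64, 40, device="cuda", dtype=torch.float16)

# op=None (the default) lets xFormers dispatch to the best kernel it finds
out_auto = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=None)

# An explicit op forces that kernel, which is what op=self._attention_op does above
out_flash = xformers.ops.memory_efficient_attention(
    q, k, v, attn_bias=None, op=MemoryEfficientAttentionFlashAttentionOp
)
```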
Changes to `CrossAttention` and `XFormersCrossAttnProcessor`:

@@ -11,7 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-from typing import Optional, Union
+from typing import Callable, Optional, Union

 import torch
 import torch.nn.functional as F

@@ -93,7 +93,9 @@ class CrossAttention(nn.Module):
         processor = processor if processor is not None else CrossAttnProcessor()
         self.set_processor(processor)

-    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
+    def set_use_memory_efficient_attention_xformers(
+        self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None
+    ):
         if use_memory_efficient_attention_xformers:
             if self.added_kv_proj_dim is not None:
                 # TODO(Anton, Patrick, Suraj, William) - currently xformers doesn't work for UnCLIP

@@ -127,7 +129,7 @@
                 except Exception as e:
                     raise e

-            processor = XFormersCrossAttnProcessor()
+            processor = XFormersCrossAttnProcessor(attention_op=attention_op)
         else:
             processor = CrossAttnProcessor()

@@ -351,6 +353,9 @@ class CrossAttnAddedKVProcessor:


 class XFormersCrossAttnProcessor:
+    def __init__(self, attention_op: Optional[Callable] = None):
+        self.attention_op = attention_op
+
     def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
         batch_size, sequence_length, _ = hidden_states.shape

@@ -366,7 +371,9 @@ class XFormersCrossAttnProcessor:
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

-        hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+        hidden_states = xformers.ops.memory_efficient_attention(
+            query, key, value, attn_bias=attention_mask, op=self.attention_op
+        )
         hidden_states = hidden_states.to(query.dtype)
         hidden_states = attn.batch_to_head_dim(hidden_states)
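For `CrossAttention`, the operator travels inside the processor object instead of a module attribute: enabling xFormers now constructs `XFormersCrossAttnProcessor(attention_op=attention_op)` and installs it via `set_processor`. A sketch of doing the same by hand on a single module (the module sizes and the `diffusers.models.cross_attention` import path are illustrative assumptions for the diffusers version this PR targets):

```py
import torch
from diffusers.models.cross_attention import CrossAttention, XFormersCrossAttnProcessor
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

# A standalone cross-attention module with hypothetical sizes
attn = CrossAttention(query_dim=320, heads=8, dim_head=40).to(device="cuda", dtype=torch.float16)

# Equivalent in effect to attn.set_use_memory_efficient_attention_xformers(True, MemoryEfficientAttentionFlashAttentionOp),
# minus the xFormers availability checks that method performs
attn.set_processor(XFormersCrossAttnProcessor(attention_op=MemoryEfficientAttentionFlashAttentionOp))
```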
Changes to `ModelMixin`:

@@ -190,13 +190,15 @@ class ModelMixin(torch.nn.Module):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))

-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

             for child in module.children():
                 fn_recursive_set_mem_eff(child)

@@ -205,7 +207,7 @@
             if isinstance(module, torch.nn.Module):
                 fn_recursive_set_mem_eff(module)

-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.

@@ -214,8 +216,28 @@

         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.

+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import UNet2DConditionModel
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> model = UNet2DConditionModel.from_pretrained(
+        ...     "stabilityai/stable-diffusion-2-1", subfolder="unet", torch_dtype=torch.float16
+        ... )
+        >>> model = model.to("cuda")
+        >>> model.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)

     def disable_xformers_memory_efficient_attention(self):
         r"""
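`ModelMixin` itself does not touch attention; it only broadcasts the flag and the operator to every child module that exposes `set_use_memory_efficient_attention_xformers`. A minimal sketch of that recursive dispatch with toy stand-in modules (illustrative names, not the actual diffusers classes):

```py
import torch

class ToyAttention(torch.nn.Module):
    """Stands in for AttentionBlock / CrossAttention."""

    def set_use_memory_efficient_attention_xformers(self, valid, attention_op=None):
        print(f"{self.__class__.__name__}: valid={valid}, attention_op={attention_op}")

# Nested container, like a UNet holding attention blocks at several depths
model = torch.nn.Sequential(torch.nn.Sequential(ToyAttention()), ToyAttention())

def fn_recursive_set_mem_eff(module: torch.nn.Module, valid, attention_op):
    if hasattr(module, "set_use_memory_efficient_attention_xformers"):
        module.set_use_memory_efficient_attention_xformers(valid, attention_op)
    for child in module.children():
        fn_recursive_set_mem_eff(child, valid, attention_op)

# Both ToyAttention instances receive the flag and the operator
fn_recursive_set_mem_eff(model, True, "some_op")
```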
Changes to `DiffusionPipeline`:

@@ -19,7 +19,7 @@ import inspect
 import os
 from dataclasses import dataclass
 from pathlib import Path
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union

 import numpy as np
 import torch

@@ -842,7 +842,7 @@ class DiffusionPipeline(ConfigMixin):
     def set_progress_bar_config(self, **kwargs):
         self._progress_bar_config = kwargs

-    def enable_xformers_memory_efficient_attention(self):
+    def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
         Enable memory efficient attention as implemented in xformers.

@@ -851,8 +851,28 @@

         Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
         is used.

+        Parameters:
+            attention_op (`Callable`, *optional*):
+                Override the default `None` operator for use as `op` argument to the
+                [`memory_efficient_attention()`](https://facebookresearch.github.io/xformers/components/ops.html#xformers.ops.memory_efficient_attention)
+                function of xFormers.
+
+        Examples:
+
+        ```py
+        >>> import torch
+        >>> from diffusers import DiffusionPipeline
+        >>> from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+        >>> pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16)
+        >>> pipe = pipe.to("cuda")
+        >>> pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+        >>> # Workaround for not accepting attention shape using VAE for Flash Attention
+        >>> pipe.vae.enable_xformers_memory_efficient_attention(attention_op=None)
+        ```
         """
-        self.set_use_memory_efficient_attention_xformers(True)
+        self.set_use_memory_efficient_attention_xformers(True, attention_op)

     def disable_xformers_memory_efficient_attention(self):
         r"""

@@ -860,13 +880,15 @@
         """
         self.set_use_memory_efficient_attention_xformers(False)

-    def set_use_memory_efficient_attention_xformers(self, valid: bool) -> None:
+    def set_use_memory_efficient_attention_xformers(
+        self, valid: bool, attention_op: Optional[Callable] = None
+    ) -> None:
         # Recursively walk through all the children.
         # Any children which exposes the set_use_memory_efficient_attention_xformers method
         # gets the message
         def fn_recursive_set_mem_eff(module: torch.nn.Module):
             if hasattr(module, "set_use_memory_efficient_attention_xformers"):
-                module.set_use_memory_efficient_attention_xformers(valid)
+                module.set_use_memory_efficient_attention_xformers(valid, attention_op)

             for child in module.children():
                 fn_recursive_set_mem_eff(child)
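Because `attention_op` defaults to `None` at every layer of the call chain, existing code keeps its behavior: `op=None` tells xFormers to auto-select the kernel, exactly as before this change. A short sketch contrasting the default with an explicit operator (assumes xFormers and a CUDA device):

```py
import torch
from diffusers import DiffusionPipeline
from xformers.ops import MemoryEfficientAttentionFlashAttentionOp

pipe = DiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float16
).to("cuda")

# Default: attention_op stays None, so xFormers auto-selects the kernel (previous behavior)
pipe.enable_xformers_memory_efficient_attention()

# Explicit: the given operator is forced for every attention module in the pipeline
pipe.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)

# xFormers attention can still be switched off entirely
pipe.disable_xformers_memory_efficient_attention()
```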