[refactor] make set_attention_slice recursive (#1532)

* make attn slice recursive
* remove set_attention_slice from blocks
* fix copies
* make enable_attention_slicing base class method of DiffusionPipeline
* fix set_attention_slice
* fix set_attention_slice
* fix copies
* add tests
* up
* up
* up
* update
* up
* uP

Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com>
This commit is contained in:
parent e289998932
commit bce65cd13a
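For orientation, here is a hedged usage sketch of the API this commit refactors (not part of the diff itself); the checkpoint name is an illustrative assumption.

import torch
from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; any Stable Diffusion checkpoint should behave the same way.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

# enable_attention_slicing now lives on DiffusionPipeline and forwards the value
# to every registered sub-module that exposes set_attention_slice.
pipe.enable_attention_slicing("auto")
image = pipe("an astronaut riding a horse").images[0]

pipe.disable_attention_slicing()  # equivalent to enable_attention_slicing(None)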
@@ -174,10 +174,6 @@ class Transformer2DModel(ModelMixin, ConfigMixin):
            self.norm_out = nn.LayerNorm(inner_dim)
            self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)

-    def _set_attention_slice(self, slice_size):
-        for block in self.transformer_blocks:
-            block._set_attention_slice(slice_size)
-
    def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
        """
        Args:

@@ -448,10 +444,6 @@ class BasicTransformerBlock(nn.Module):
                    f" correctly and a GPU is available: {e}"
                )

-    def _set_attention_slice(self, slice_size):
-        self.attn1._slice_size = slice_size
-        self.attn2._slice_size = slice_size
-
    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool):
        if not is_xformers_available():
            print("Here is how to install it")

@@ -534,6 +526,7 @@ class CrossAttention(nn.Module):
        # for slice_size > 0 the attention score computation
        # is split across the batch axis to save memory
        # You can set slice_size with `set_attention_slice`
+        self.sliceable_head_dim = heads
        self._slice_size = None
        self._use_memory_efficient_attention_xformers = False

@@ -559,6 +552,12 @@ class CrossAttention(nn.Module):
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

+    def set_attention_slice(self, slice_size):
+        if slice_size is not None and slice_size > self.sliceable_head_dim:
+            raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+        self._slice_size = slice_size
+
    def forward(self, hidden_states, context=None, mask=None):
        batch_size, sequence_length, _ = hidden_states.shape

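The comments in the hunk above describe what `_slice_size` buys: instead of materializing attention scores for the full batch-times-heads dimension at once, the computation is chunked. The snippet below is a minimal, self-contained illustration of that idea (a sketch, not the diffusers implementation), assuming query/key/value are already reshaped to (batch * heads, seq_len, head_dim).

import torch

def sliced_attention(query, key, value, slice_size=None):
    # Only one chunk of the (seq_len x seq_len) score matrix is alive at a time,
    # which bounds the peak memory of the attention computation.
    scale = query.shape[-1] ** -0.5
    step = slice_size if slice_size is not None else query.shape[0]
    chunks = []
    for start in range(0, query.shape[0], step):
        s = slice(start, start + step)
        scores = torch.softmax(query[s] @ key[s].transpose(1, 2) * scale, dim=-1)
        chunks.append(scores @ value[s])
    return torch.cat(chunks)

# quick check: slicing does not change the result
q, k, v = (torch.randn(8, 16, 40) for _ in range(3))
assert torch.allclose(sliced_attention(q, k, v, 2), sliced_attention(q, k, v), atol=1e-6)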
@@ -401,23 +401,6 @@ class UNetMidBlock2DCrossAttn(nn.Module):
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):

@@ -595,23 +578,6 @@ class CrossAttnDownBlock2D(nn.Module):

        self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

@@ -1190,25 +1156,6 @@ class CrossAttnUpBlock2D(nn.Module):

        self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-        self.gradient_checkpointing = False
-
    def forward(
        self,
        hidden_states,

@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from dataclasses import dataclass
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn as nn

@@ -229,28 +229,69 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
        self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
-        head_dims = self.config.attention_head_dim
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

-        self.mid_block.set_attention_slice(slice_size)
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())

-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):

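The two nested helpers in the hunk above are the heart of the refactor: first walk the module tree to collect every sliceable head dimension, then walk it again and hand one entry of the slice list to each module that exposes set_attention_slice. The toy sketch below (hypothetical ToyAttention class, plain PyTorch, not taken from the diff) reproduces that collect-then-distribute pattern.

import torch.nn as nn

class ToyAttention(nn.Module):
    """Stand-in for CrossAttention: exposes sliceable_head_dim / set_attention_slice."""
    def __init__(self, heads):
        super().__init__()
        self.sliceable_head_dim = heads
        self._slice_size = None

    def set_attention_slice(self, slice_size):
        self._slice_size = slice_size

model = nn.Sequential(nn.Linear(4, 4), ToyAttention(8), nn.Sequential(ToyAttention(4)))

# 1) collect every sliceable head dim, depth first
dims = []
def collect(module):
    if hasattr(module, "set_attention_slice"):
        dims.append(module.sliceable_head_dim)
    for child in module.children():
        collect(child)
for m in model.children():
    collect(m)

# 2) distribute one slice size per sliceable layer, popping from a reversed copy
#    so the traversal order matches the collected order ("auto" behaviour here)
sizes = list(reversed([d // 2 for d in dims]))
def apply(module, sizes):
    if hasattr(module, "set_attention_slice"):
        module.set_attention_slice(sizes.pop())
    for child in module.children():
        apply(child, sizes)
for m in model.children():
    apply(m, sizes)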
@@ -839,3 +839,34 @@ class DiffusionPipeline(ConfigMixin):
            module = getattr(self, module_name)
            if isinstance(module, torch.nn.Module):
                fn_recursive_set_mem_eff(module)
+
+    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        self.set_attention_slice(slice_size)
+
+    def disable_attention_slicing(self):
+        r"""
+        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
+        back to computing attention in one step.
+        """
+        # set slice_size = `None` to disable `attention slicing`
+        self.enable_attention_slicing(None)
+
+    def set_attention_slice(self, slice_size: Optional[int]):
+        module_names, _, _ = self.extract_init_dict(dict(self.config))
+        for module_name in module_names:
+            module = getattr(self, module_name)
+            if isinstance(module, torch.nn.Module) and hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size)

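With enable_attention_slicing/disable_attention_slicing on the base class, every pipeline gets the same contract. A hedged summary of the accepted values, assuming `pipe` is any loaded DiffusionPipeline whose sub-modules expose set_attention_slice:

def demo_slicing_options(pipe):
    # values forwarded by the base class; an int is broadcast to all sliceable layers
    pipe.enable_attention_slicing()        # defaults to "auto": halve each sliceable head dim
    pipe.enable_attention_slicing("max")   # one slice at a time, lowest memory
    pipe.enable_attention_slicing(4)       # explicit slice size
    pipe.disable_attention_slicing()       # same as enable_attention_slicing(None)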
@@ -166,38 +166,6 @@ class AltDiffusionPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

@@ -179,38 +179,6 @@ class AltDiffusionImg2ImgPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

@@ -209,40 +209,6 @@ class CycleDiffusionPipeline(DiffusionPipeline):
        )
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""

@@ -165,38 +165,6 @@ class StableDiffusionPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_vae_slicing(self):
        r"""
        Enable sliced VAE decoding.

@@ -134,40 +134,6 @@ class StableDiffusionImageVariationPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

@@ -178,40 +178,6 @@ class StableDiffusionImg2ImgPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""

@@ -243,40 +243,6 @@ class StableDiffusionInpaintPipeline(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""

@@ -191,40 +191,6 @@ class StableDiffusionInpaintPipelineLegacy(DiffusionPipeline):
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
        self.register_to_config(requires_safety_checker=requires_safety_checker)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_sequential_cpu_offload
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""

@@ -92,40 +92,6 @@ class StableDiffusionUpscalePipeline(DiffusionPipeline):
        )
        self.register_to_config(max_noise_level=max_noise_level)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.unet.config.attention_head_dim)
-
-        self.unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

@@ -182,33 +182,6 @@ class StableDiffusionPipelineSafe(DiffusionPipeline):
        """
        self._safety_text_concept = concept

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.unet.config.attention_head_dim // 2
-        self.unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

@@ -1,4 +1,4 @@
-from typing import Optional, Tuple, Union
+from typing import List, Optional, Tuple, Union

 import numpy as np
 import torch

@@ -307,28 +307,69 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
        self.conv_out = LinearMultiDim(block_out_channels[0], out_channels, kernel_size=3, padding=1)

    def set_attention_slice(self, slice_size):
-        head_dims = self.config.attention_head_dim
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
+        r"""
+        Enable sliced attention computation.
+
+        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+        in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+        Args:
+            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
+                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+                provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+                must be a multiple of `slice_size`.
+        """
+        sliceable_head_dims = []
+
+        def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                sliceable_head_dims.append(module.sliceable_head_dim)
+
+            for child in module.children():
+                fn_recursive_retrieve_slicable_dims(child)
+
+        # retrieve number of attention layers
+        for module in self.children():
+            fn_recursive_retrieve_slicable_dims(module)
+
+        num_slicable_layers = len(sliceable_head_dims)
+
+        if slice_size == "auto":
+            # half the attention head size is usually a good trade-off between
+            # speed and memory
+            slice_size = [dim // 2 for dim in sliceable_head_dims]
+        elif slice_size == "max":
+            # make smallest slice possible
+            slice_size = num_slicable_layers * [1]
+
+        slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+        if len(slice_size) != len(sliceable_head_dims):
            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
+                f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
            )

-        for block in self.down_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+        for i in range(len(slice_size)):
+            size = slice_size[i]
+            dim = sliceable_head_dims[i]
+            if size is not None and size > dim:
+                raise ValueError(f"size {size} has to be smaller or equal to {dim}.")

-        self.mid_block.set_attention_slice(slice_size)
+        # Recursively walk through all the children.
+        # Any children which exposes the set_attention_slice method
+        # gets the message
+        def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+            if hasattr(module, "set_attention_slice"):
+                module.set_attention_slice(slice_size.pop())

-        for block in self.up_blocks:
-            if hasattr(block, "attentions") and block.attentions is not None:
-                block.set_attention_slice(slice_size)
+            for child in module.children():
+                fn_recursive_set_attention_slice(child, slice_size)
+
+        reversed_slice_size = list(reversed(slice_size))
+        for module in self.children():
+            fn_recursive_set_attention_slice(module, reversed_slice_size)

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):

@@ -739,23 +780,6 @@ class CrossAttnDownBlockFlat(nn.Module):

        self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        output_states = ()

@@ -948,25 +972,6 @@ class CrossAttnUpBlockFlat(nn.Module):

        self.gradient_checkpointing = False

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
-        self.gradient_checkpointing = False
-
    def forward(
        self,
        hidden_states,

@@ -1092,23 +1097,6 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
        self.attentions = nn.ModuleList(attentions)
        self.resnets = nn.ModuleList(resnets)

-    def set_attention_slice(self, slice_size):
-        head_dims = self.attn_num_head_channels
-        head_dims = [head_dims] if isinstance(head_dims, int) else head_dims
-        if slice_size is not None and any(dim % slice_size != 0 for dim in head_dims):
-            raise ValueError(
-                f"Make sure slice_size {slice_size} is a common divisor of "
-                f"the number of heads used in cross_attention: {head_dims}"
-            )
-        if slice_size is not None and slice_size > min(head_dims):
-            raise ValueError(
-                f"slice_size {slice_size} has to be smaller or equal to "
-                f"the lowest number of heads used in cross_attention: min({head_dims}) = {min(head_dims)}"
-            )
-
-        for attn in self.attentions:
-            attn._set_attention_slice(slice_size)
-
    def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
        hidden_states = self.resnets[0](hidden_states, temb)
        for attn, resnet in zip(self.attentions, self.resnets[1:]):

@@ -80,34 +80,6 @@ class VersatileDiffusionPipeline(DiffusionPipeline):
        )
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)

-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            # half the attention head size is usually a good trade-off between
-            # speed and memory
-            slice_size = self.image_unet.config.attention_head_dim // 2
-        self.image_unet.set_attention_slice(slice_size)
-        self.text_unet.set_attention_slice(slice_size)
-
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    @torch.no_grad()
    def image_variation(
        self,

@@ -147,40 +147,6 @@ class VersatileDiffusionDualGuidedPipeline(DiffusionPipeline):

        self.image_unet.register_to_config(dual_cross_attention=False)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.image_unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.image_unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.image_unet.config.attention_head_dim)
-
-        self.image_unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

|
|||
)
|
||||
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
|
||||
def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
|
||||
r"""
|
||||
Enable sliced attention computation.
|
||||
|
||||
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
|
||||
in several steps. This is useful to save some memory in exchange for a small speed decrease.
|
||||
|
||||
Args:
|
||||
slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
|
||||
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
|
||||
a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
|
||||
`attention_head_dim` must be a multiple of `slice_size`.
|
||||
"""
|
||||
if slice_size == "auto":
|
||||
if isinstance(self.image_unet.config.attention_head_dim, int):
|
||||
# half the attention head size is usually a good trade-off between
|
||||
# speed and memory
|
||||
slice_size = self.image_unet.config.attention_head_dim // 2
|
||||
else:
|
||||
# if `attention_head_dim` is a list, take the smallest head size
|
||||
slice_size = min(self.image_unet.config.attention_head_dim)
|
||||
|
||||
self.image_unet.set_attention_slice(slice_size)
|
||||
|
||||
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
|
||||
def disable_attention_slicing(self):
|
||||
r"""
|
||||
Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
|
||||
back to computing attention in one step.
|
||||
"""
|
||||
# set slice_size = `None` to disable `attention slicing`
|
||||
self.enable_attention_slicing(None)
|
||||
|
||||
def enable_sequential_cpu_offload(self, gpu_id=0):
|
||||
r"""
|
||||
Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
|
||||
|
|
|
@@ -98,40 +98,6 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
    def remove_unused_weights(self):
        self.register_modules(text_unet=None)

-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.enable_attention_slicing with unet->image_unet
-    def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-        r"""
-        Enable sliced attention computation.
-
-        When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-        in several steps. This is useful to save some memory in exchange for a small speed decrease.
-
-        Args:
-            slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-                When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-                a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-                `attention_head_dim` must be a multiple of `slice_size`.
-        """
-        if slice_size == "auto":
-            if isinstance(self.image_unet.config.attention_head_dim, int):
-                # half the attention head size is usually a good trade-off between
-                # speed and memory
-                slice_size = self.image_unet.config.attention_head_dim // 2
-            else:
-                # if `attention_head_dim` is a list, take the smallest head size
-                slice_size = min(self.image_unet.config.attention_head_dim)
-
-        self.image_unet.set_attention_slice(slice_size)
-
-    # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.disable_attention_slicing
-    def disable_attention_slicing(self):
-        r"""
-        Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-        back to computing attention in one step.
-        """
-        # set slice_size = `None` to disable `attention slicing`
-        self.enable_attention_slicing(None)
-
    def enable_sequential_cpu_offload(self, gpu_id=0):
        r"""
        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,

@@ -334,6 +334,48 @@ class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
        expected_shape = inputs_dict["sample"].shape
        self.assertEqual(output.shape, expected_shape, "Input and output shapes do not match")

+    def test_model_attention_slicing(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+        model.eval()
+
+        model.set_attention_slice("auto")
+        with torch.no_grad():
+            output = model(**inputs_dict)
+        assert output is not None
+
+        model.set_attention_slice("max")
+        with torch.no_grad():
+            output = model(**inputs_dict)
+        assert output is not None
+
+        model.set_attention_slice(2)
+        with torch.no_grad():
+            output = model(**inputs_dict)
+        assert output is not None
+
+    def test_model_slicable_head_dim(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+
+        init_dict["attention_head_dim"] = (8, 16)
+
+        model = self.model_class(**init_dict)
+
+        def check_slicable_dim_attr(module: torch.nn.Module):
+            if hasattr(module, "set_attention_slice"):
+                assert isinstance(module.sliceable_head_dim, int)
+
+            for child in module.children():
+                check_slicable_dim_attr(child)
+
+        # retrieve number of attention layers
+        for module in model.children():
+            check_slicable_dim_attr(module)
+

 class NCSNppModelTests(ModelTesterMixin, unittest.TestCase):
    model_class = UNet2DModel

@@ -479,6 +521,84 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):

        return model

+    def test_set_attention_slice_auto(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        unet = self.get_unet_model()
+        unet.set_attention_slice("auto")
+
+        latents = self.get_latents(33)
+        encoder_hidden_states = self.get_encoder_hidden_states(33)
+        timestep = 1
+
+        with torch.no_grad():
+            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        assert mem_bytes < 5 * 10**9
+
+    def test_set_attention_slice_max(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        unet = self.get_unet_model()
+        unet.set_attention_slice("max")
+
+        latents = self.get_latents(33)
+        encoder_hidden_states = self.get_encoder_hidden_states(33)
+        timestep = 1
+
+        with torch.no_grad():
+            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        assert mem_bytes < 5 * 10**9
+
+    def test_set_attention_slice_int(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        unet = self.get_unet_model()
+        unet.set_attention_slice(2)
+
+        latents = self.get_latents(33)
+        encoder_hidden_states = self.get_encoder_hidden_states(33)
+        timestep = 1
+
+        with torch.no_grad():
+            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        assert mem_bytes < 5 * 10**9
+
+    def test_set_attention_slice_list(self):
+        torch.cuda.empty_cache()
+        torch.cuda.reset_max_memory_allocated()
+        torch.cuda.reset_peak_memory_stats()
+
+        # there are 32 slicable layers
+        slice_list = 16 * [2, 3]
+        unet = self.get_unet_model()
+        unet.set_attention_slice(slice_list)
+
+        latents = self.get_latents(33)
+        encoder_hidden_states = self.get_encoder_hidden_states(33)
+        timestep = 1
+
+        with torch.no_grad():
+            _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
+
+        mem_bytes = torch.cuda.max_memory_allocated()
+
+        assert mem_bytes < 5 * 10**9
+
    def get_encoder_hidden_states(self, seed=0, shape=(4, 77, 768), fp16=False):
        dtype = torch.float16 if fp16 else torch.float32
        hidden_states = torch.from_numpy(load_hf_numpy(self.get_file_format(seed, shape))).to(torch_device).to(dtype)

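The integration tests above all follow the same measurement recipe. Below is a condensed sketch of that recipe (assuming a CUDA device and model/tensors already prepared as in the tests; the 5 GB bound simply mirrors the assertions above).

import torch

def peak_memory_of_forward(unet, latents, timestep, encoder_hidden_states):
    # reset the CUDA peak-memory counters, run one denoising step, read the high-water mark
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    with torch.no_grad():
        _ = unet(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
    return torch.cuda.max_memory_allocated()

# usage, mirroring the tests:
# unet.set_attention_slice("max")
# assert peak_memory_of_forward(unet, latents, 1, encoder_hidden_states) < 5 * 10**9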
@@ -500,6 +620,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -526,6 +648,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -552,6 +676,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed)
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -578,6 +704,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed, fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -604,6 +732,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed, shape=(4, 9, 64, 64))
        encoder_hidden_states = self.get_encoder_hidden_states(seed)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -630,6 +760,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed, shape=(4, 9, 64, 64), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, fp16=True)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample

@@ -656,6 +788,8 @@ class UNet2DConditionModelIntegrationTests(unittest.TestCase):
        latents = self.get_latents(seed, shape=(4, 4, 96, 96), fp16=True)
        encoder_hidden_states = self.get_encoder_hidden_states(seed, shape=(4, 77, 1024), fp16=True)

+        timestep = torch.tensor([timestep], dtype=torch.long, device=torch_device)
+
        with torch.no_grad():
            sample = model(latents, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample