diff --git a/src/diffusers/modeling_utils.py b/src/diffusers/modeling_utils.py index 6935fc12..659f2ee8 100644 --- a/src/diffusers/modeling_utils.py +++ b/src/diffusers/modeling_utils.py @@ -15,6 +15,7 @@ # limitations under the License. import os +from functools import partial from typing import Callable, List, Optional, Tuple, Union import torch @@ -121,10 +122,42 @@ class ModelMixin(torch.nn.Module): """ config_name = CONFIG_NAME _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"] + _supports_gradient_checkpointing = False def __init__(self): super().__init__() + @property + def is_gradient_checkpointing(self) -> bool: + """ + Whether gradient checkpointing is activated for this model or not. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules()) + + def enable_gradient_checkpointing(self): + """ + Activates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if not self._supports_gradient_checkpointing: + raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.") + self.apply(partial(self._set_gradient_checkpointing, value=True)) + + def disable_gradient_checkpointing(self): + """ + Deactivates gradient checkpointing for the current model. + + Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint + activations". + """ + if self._supports_gradient_checkpointing: + self.apply(partial(self._set_gradient_checkpointing, value=False)) + def save_pretrained( self, save_directory: Union[str, os.PathLike], diff --git a/src/diffusers/models/unet_2d_condition.py b/src/diffusers/models/unet_2d_condition.py index cdb04621..5e3ee091 100644 --- a/src/diffusers/models/unet_2d_condition.py +++ b/src/diffusers/models/unet_2d_condition.py @@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union import torch import torch.nn as nn +import torch.utils.checkpoint from ..configuration_utils import ConfigMixin, register_to_config from ..modeling_utils import ModelMixin from ..utils import BaseOutput from .embeddings import TimestepEmbedding, Timesteps -from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block +from .unet_blocks import ( + CrossAttnDownBlock2D, + CrossAttnUpBlock2D, + DownBlock2D, + UNetMidBlock2DCrossAttn, + UpBlock2D, + get_down_block, + get_up_block, +) @dataclass @@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. """ + _supports_gradient_checkpointing = True + @register_to_config def __init__( self, @@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): if hasattr(block, "attentions") and block.attentions is not None: block.set_attention_slice(slice_size) + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)): + module.gradient_checkpointing = value + def forward( self, sample: torch.FloatTensor, @@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin): for downsample_block in self.down_blocks: if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None: sample, res_samples = downsample_block( - hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, ) else: sample, res_samples = downsample_block(hidden_states=sample, temb=emb) diff --git a/src/diffusers/models/unet_blocks.py b/src/diffusers/models/unet_blocks.py index 1fee670c..f42389b9 100644 --- a/src/diffusers/models/unet_blocks.py +++ b/src/diffusers/models/unet_blocks.py @@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module): else: self.downsamplers = None + self.gradient_checkpointing = False + def set_attention_slice(self, slice_size): if slice_size is not None and self.attn_num_head_channels % slice_size != 0: raise ValueError( @@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module): output_states = () for resnet, attn in zip(self.resnets, self.attentions): - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) + output_states += (hidden_states,) if self.downsamplers is not None: @@ -609,11 +625,24 @@ class DownBlock2D(nn.Module): else: self.downsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states, temb=None): output_states = () for resnet in self.resnets: - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + output_states += (hidden_states,) if self.downsamplers is not None: @@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module): else: self.upsamplers = None + self.gradient_checkpointing = False + def set_attention_slice(self, slice_size): if slice_size is not None and self.attn_num_head_channels % slice_size != 0: raise ValueError( @@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module): for attn in self.attentions: attn._set_attention_slice(slice_size) - def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None): + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + ): for resnet, attn in zip(self.resnets, self.attentions): # pop res hidden states res_hidden_states = res_hidden_states_tuple[-1] res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) - hidden_states = attn(hidden_states, context=encoder_hidden_states) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn), hidden_states, encoder_hidden_states + ) + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, context=encoder_hidden_states) if self.upsamplers is not None: for upsampler in self.upsamplers: @@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module): else: self.upsamplers = None + self.gradient_checkpointing = False + def forward(self, hidden_states, res_hidden_states_tuple, temb=None): for resnet in self.resnets: # pop res hidden states @@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module): res_hidden_states_tuple = res_hidden_states_tuple[:-1] hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) - hidden_states = resnet(hidden_states, temb) + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) if self.upsamplers is not None: for upsampler in self.upsamplers: diff --git a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py index 8caa11db..43b734c9 100644 --- a/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +++ b/src/diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py @@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module): class LDMBertPreTrainedModel(PreTrainedModel): config_class = LDMBertConfig base_model_prefix = "model" - supports_gradient_checkpointing = True + _supports_gradient_checkpointing = True _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"] def _init_weights(self, module): diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py index 1e98fc9d..b0d00b86 100644 --- a/tests/test_modeling_common.py +++ b/tests/test_modeling_common.py @@ -246,3 +246,21 @@ class ModelTesterMixin: outputs_tuple = model(**inputs_dict, return_dict=False) recursive_check(outputs_tuple, outputs_dict) + + def test_enable_disable_gradient_checkpointing(self): + if not self.model_class._supports_gradient_checkpointing: + return # Skip test if model does not support gradient checkpointing + + init_dict, _ = self.prepare_init_args_and_inputs_for_common() + + # at init model should have gradient checkpointing disabled + model = self.model_class(**init_dict) + self.assertFalse(model.is_gradient_checkpointing) + + # check enable works + model.enable_gradient_checkpointing() + self.assertTrue(model.is_gradient_checkpointing) + + # check disable works + model.disable_gradient_checkpointing() + self.assertFalse(model.is_gradient_checkpointing) diff --git a/tests/test_models_unet.py b/tests/test_models_unet.py index 94e7cd5b..80055c1a 100644 --- a/tests/test_models_unet.py +++ b/tests/test_models_unet.py @@ -18,7 +18,7 @@ import unittest import torch -from diffusers import UNet2DModel +from diffusers import UNet2DConditionModel, UNet2DModel from diffusers.testing_utils import floats_tensor, slow, torch_device from .test_modeling_common import ModelTesterMixin @@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase): self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3)) +class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase): + model_class = UNet2DConditionModel + + @property + def dummy_input(self): + batch_size = 4 + num_channels = 4 + sizes = (32, 32) + + noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device) + time_step = torch.tensor([10]).to(torch_device) + encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device) + + return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states} + + @property + def input_shape(self): + return (4, 32, 32) + + @property + def output_shape(self): + return (4, 32, 32) + + def prepare_init_args_and_inputs_for_common(self): + init_dict = { + "block_out_channels": (32, 64), + "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"), + "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"), + "cross_attention_dim": 32, + "attention_head_dim": 8, + "out_channels": 4, + "in_channels": 4, + "layers_per_block": 2, + "sample_size": 32, + } + inputs_dict = self.dummy_input + return init_dict, inputs_dict + + def test_gradient_checkpointing(self): + init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common() + model = self.model_class(**init_dict) + model.to(torch_device) + + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + out.sum().backward() + + # now we save the output and parameter gradients that we will use for comparison purposes with + # the non-checkpointed run. + output_not_checkpointed = out.data.clone() + grad_not_checkpointed = {} + for name, param in model.named_parameters(): + grad_not_checkpointed[name] = param.grad.data.clone() + + model.enable_gradient_checkpointing() + out = model(**inputs_dict).sample + # run the backwards pass on the model. For backwards pass, for simplicity purpose, + # we won't calculate the loss and rather backprop on out.sum() + model.zero_grad() + out.sum().backward() + + # now we save the output and parameter gradients that we will use for comparison purposes with + # the non-checkpointed run. + output_checkpointed = out.data.clone() + grad_checkpointed = {} + for name, param in model.named_parameters(): + grad_checkpointed[name] = param.grad.data.clone() + + # compare the output and parameters gradients + self.assertTrue((output_checkpointed == output_not_checkpointed).all()) + for name in grad_checkpointed: + self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5)) + + # TODO(Patrick) - Re-add this test after having cleaned up LDM # def test_output_pretrained_spatial_transformer(self): # model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")