[UNet2DConditionModel] add gradient checkpointing (#461)

* add grad ckpt to downsample blocks

* make it work

* don't pass gradient_checkpointing to upsample block

* add tests for UNet2DConditionModel

* add test_gradient_checkpointing

* add gradient_checkpointing for up and down blocks

* add functions to enable and disable grad ckpt

* remove the forward argument

* better naming

* make supports_gradient_checkpointing private
Suraj Patil 2022-09-22 15:36:47 +02:00 committed by GitHub
parent 534512bedb
commit e7120bae95
6 changed files with 219 additions and 11 deletions

View File

@@ -15,6 +15,7 @@
# limitations under the License.
import os
from functools import partial
from typing import Callable, List, Optional, Tuple, Union
import torch
@@ -121,10 +122,42 @@ class ModelMixin(torch.nn.Module):
"""
config_name = CONFIG_NAME
_automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
_supports_gradient_checkpointing = False
def __init__(self):
super().__init__()
@property
def is_gradient_checkpointing(self) -> bool:
"""
Whether gradient checkpointing is activated for this model or not.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
def enable_gradient_checkpointing(self):
"""
Activates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if not self._supports_gradient_checkpointing:
raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
self.apply(partial(self._set_gradient_checkpointing, value=True))
def disable_gradient_checkpointing(self):
"""
Deactivates gradient checkpointing for the current model.
Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
activations".
"""
if self._supports_gradient_checkpointing:
self.apply(partial(self._set_gradient_checkpointing, value=False))
def save_pretrained(
self,
save_directory: Union[str, os.PathLike],

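As a quick orientation for the API added to ModelMixin above, here is a minimal usage sketch. It is not part of the diff; the config values are simply the small init_dict the new tests use further down.

# Illustrative sketch, not code from this commit.
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

model.train()  # the blocks only checkpoint when self.training is True
model.enable_gradient_checkpointing()
assert model.is_gradient_checkpointing

model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing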
View File

@@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
import torch.utils.checkpoint
from ..configuration_utils import ConfigMixin, register_to_config
from ..modeling_utils import ModelMixin
from ..utils import BaseOutput
from .embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
from .unet_blocks import (
CrossAttnDownBlock2D,
CrossAttnUpBlock2D,
DownBlock2D,
UNetMidBlock2DCrossAttn,
UpBlock2D,
get_down_block,
get_up_block,
)
@dataclass
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
"""
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
@@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
if hasattr(block, "attentions") and block.attentions is not None:
block.set_attention_slice(slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
@@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
)
else:
sample, res_samples = downsample_block(hidden_states=sample, temb=emb)

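A note on the _set_gradient_checkpointing hook above: enable_gradient_checkpointing in ModelMixin walks every submodule with torch.nn.Module.apply and a functools.partial, and the hook flips the flag only on the block classes it recognizes. A self-contained sketch of that pattern, with toy module names assumed for illustration (they are not the real UNet blocks):

# Toy illustration of the apply + partial pattern; ToyBlock stands in for
# CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D and UpBlock2D.
from functools import partial

import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock(), ToyBlock()])

    def _set_gradient_checkpointing(self, module, value=False):
        # only blocks that implement checkpointing get the flag
        if isinstance(module, ToyBlock):
            module.gradient_checkpointing = value


model = ToyModel()
# Module.apply visits the model and every submodule recursively
model.apply(partial(model._set_gradient_checkpointing, value=True))
assert all(block.gradient_checkpointing for block in model.blocks)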
View File

@@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
@@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module):
output_states = ()
for resnet, attn in zip(self.resnets, self.attentions):
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
@@ -609,11 +625,24 @@ class DownBlock2D(nn.Module):
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None):
output_states = ()
for resnet in self.resnets:
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
output_states += (hidden_states,)
if self.downsamplers is not None:
@@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
def set_attention_slice(self, slice_size):
if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
raise ValueError(
@@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module):
for attn in self.attentions:
attn._set_attention_slice(slice_size)
def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
):
for resnet, attn in zip(self.resnets, self.attentions):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn), hidden_states, encoder_hidden_states
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, context=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
@@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module):
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
for resnet in self.resnets:
# pop res hidden states
@@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module):
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
hidden_states = resnet(hidden_states, temb)
if self.training and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
else:
hidden_states = resnet(hidden_states, temb)
if self.upsamplers is not None:
for upsampler in self.upsamplers:

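For readers new to the pattern repeated in the blocks above, here is a standalone sketch (toy module assumed, not code from this commit): torch.utils.checkpoint.checkpoint discards the wrapped forward's intermediate activations and recomputes them during the backward pass, trading extra compute for lower memory, and the create_custom_forward wrapper turns a module into the plain positional-argument callable that checkpoint expects.

# Minimal demonstration of activation checkpointing with a custom-forward wrapper.
import torch
import torch.nn as nn
import torch.utils.checkpoint


class ToyResnet(nn.Module):
    def __init__(self, channels=8):
        super().__init__()
        self.conv = nn.Conv2d(channels, channels, kernel_size=3, padding=1)

    def forward(self, hidden_states, temb=None):
        return torch.relu(self.conv(hidden_states))


def create_custom_forward(module):
    # checkpoint() re-invokes this plain function with the saved inputs
    # during the backward pass to recompute the activations
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


block = ToyResnet().train()
hidden_states = torch.randn(2, 8, 16, 16, requires_grad=True)
temb = torch.randn(2, 8)

# activations inside ToyResnet are not stored; they are recomputed on backward
out = torch.utils.checkpoint.checkpoint(create_custom_forward(block), hidden_states, temb)
out.sum().backward()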
View File

@@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module):
class LDMBertPreTrainedModel(PreTrainedModel):
config_class = LDMBertConfig
base_model_prefix = "model"
supports_gradient_checkpointing = True
_supports_gradient_checkpointing = True
_keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
def _init_weights(self, module):

View File

@@ -246,3 +246,21 @@ class ModelTesterMixin:
outputs_tuple = model(**inputs_dict, return_dict=False)
recursive_check(outputs_tuple, outputs_dict)
def test_enable_disable_gradient_checkpointing(self):
if not self.model_class._supports_gradient_checkpointing:
return # Skip test if model does not support gradient checkpointing
init_dict, _ = self.prepare_init_args_and_inputs_for_common()
# at init, the model should have gradient checkpointing disabled
model = self.model_class(**init_dict)
self.assertFalse(model.is_gradient_checkpointing)
# check enable works
model.enable_gradient_checkpointing()
self.assertTrue(model.is_gradient_checkpointing)
# check disable works
model.disable_gradient_checkpointing()
self.assertFalse(model.is_gradient_checkpointing)

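Beyond the enable/disable check above and the gradient-equivalence test below, one might also want to confirm the memory saving empirically. A rough sketch of one way to do that (not part of this commit, assumes a CUDA device, and the model and inputs would be built as in the tests):

# Hypothetical helper, not part of the test suite in this commit.
import torch


def peak_memory_mb(model, inputs):
    torch.cuda.reset_peak_memory_stats()
    out = model(**inputs).sample
    out.sum().backward()
    model.zero_grad()
    return torch.cuda.max_memory_allocated() / 2**20


# baseline = peak_memory_mb(model, inputs)
# model.enable_gradient_checkpointing()
# checkpointed = peak_memory_mb(model, inputs)
# print(f"peak memory: {baseline:.1f} MiB -> {checkpointed:.1f} MiB")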
View File

@@ -18,7 +18,7 @@ import unittest
import torch
from diffusers import UNet2DModel
from diffusers import UNet2DConditionModel, UNet2DModel
from diffusers.testing_utils import floats_tensor, slow, torch_device
from .test_modeling_common import ModelTesterMixin
@@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
model_class = UNet2DConditionModel
@property
def dummy_input(self):
batch_size = 4
num_channels = 4
sizes = (32, 32)
noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
time_step = torch.tensor([10]).to(torch_device)
encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
@property
def input_shape(self):
return (4, 32, 32)
@property
def output_shape(self):
return (4, 32, 32)
def prepare_init_args_and_inputs_for_common(self):
init_dict = {
"block_out_channels": (32, 64),
"down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
"up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
"cross_attention_dim": 32,
"attention_head_dim": 8,
"out_channels": 4,
"in_channels": 4,
"layers_per_block": 2,
"sample_size": 32,
}
inputs_dict = self.dummy_input
return init_dict, inputs_dict
def test_gradient_checkpointing(self):
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
model = self.model_class(**init_dict)
model.to(torch_device)
out = model(**inputs_dict).sample
# run the backward pass on the model; for simplicity we don't compute a loss
# and instead backprop on out.sum()
model.zero_grad()
out.sum().backward()
# save the output and parameter gradients of the non-checkpointed run for later comparison
output_not_checkpointed = out.data.clone()
grad_not_checkpointed = {}
for name, param in model.named_parameters():
grad_not_checkpointed[name] = param.grad.data.clone()
model.enable_gradient_checkpointing()
out = model(**inputs_dict).sample
# run the backward pass on the model; for simplicity we don't compute a loss
# and instead backprop on out.sum()
model.zero_grad()
out.sum().backward()
# save the output and parameter gradients of the checkpointed run to compare against
# the non-checkpointed run
output_checkpointed = out.data.clone()
grad_checkpointed = {}
for name, param in model.named_parameters():
grad_checkpointed[name] = param.grad.data.clone()
# compare the outputs and the parameter gradients
self.assertTrue((output_checkpointed == output_not_checkpointed).all())
for name in grad_checkpointed:
self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
# TODO(Patrick) - Re-add this test after having cleaned up LDM
# def test_output_pretrained_spatial_transformer(self):
# model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")