[UNet2DConditionModel] add gradient checkpointing (#461)
* add grad ckpt to downsample blocks
* make it work
* don't pass gradient_checkpointing to upsample block
* add tests for UNet2DConditionModel
* add test_gradient_checkpointing
* add gradient_checkpointing for up and down blocks
* add functions to enable and disable grad ckpt
* remove the forward argument
* better naming
* make supports_gradient_checkpointing private
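For context (not part of the commit), a minimal sketch of the API this change introduces; the constructor arguments are copied from the test configuration added further down in this diff:

```python
from diffusers import UNet2DConditionModel

# small config taken from the new UNet2DConditionModelTests below
model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

model.enable_gradient_checkpointing()   # recompute block activations during backward to save memory
assert model.is_gradient_checkpointing  # new property on ModelMixin

model.disable_gradient_checkpointing()
assert not model.is_gradient_checkpointing
```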
This commit is contained in:
parent 534512bedb
commit e7120bae95
@@ -15,6 +15,7 @@
 # limitations under the License.
 
 import os
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union
 
 import torch
@@ -121,10 +122,42 @@ class ModelMixin(torch.nn.Module):
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+    _supports_gradient_checkpointing = False
 
     def __init__(self):
         super().__init__()
 
+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if not self._supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+    def disable_gradient_checkpointing(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self._supports_gradient_checkpointing:
+            self.apply(partial(self._set_gradient_checkpointing, value=False))
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
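For context (not part of the commit): `enable_gradient_checkpointing` and `disable_gradient_checkpointing` only toggle a per-block flag; `torch.nn.Module.apply` visits every submodule, and the subclass's `_set_gradient_checkpointing` decides which blocks opt in. A standalone sketch of that pattern, using hypothetical `ToyBlock`/`ToyModel` classes rather than the diffusers ones:

```python
from functools import partial

import torch.nn as nn


class ToyBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)
        self.gradient_checkpointing = False  # per-block flag, read in forward()


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([ToyBlock() for _ in range(3)])

    def _set_gradient_checkpointing(self, module, value=False):
        # only blocks that define the flag are toggled
        if isinstance(module, ToyBlock):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # Module.apply calls the function on self and every submodule recursively
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyModel()
model.enable_gradient_checkpointing()
print(all(block.gradient_checkpointing for block in model.blocks))  # True
```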
@@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
+import torch.utils.checkpoint
 
 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
-from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+from .unet_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UpBlock2D,
+    get_down_block,
+    get_up_block,
+)
 
 
 @dataclass
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
     """
 
+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)
 
+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
-                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module):
         output_states = ()
 
         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
             output_states += (hidden_states,)
 
         if self.downsamplers is not None:
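For context (not part of the commit): `torch.utils.checkpoint.checkpoint` runs the wrapped callable without storing its intermediate activations and recomputes them during the backward pass, trading extra compute for lower memory. The `create_custom_forward` closure simply turns a module into a plain function so that additional positional arguments (`temb`, `encoder_hidden_states`) can be passed through. A minimal standalone sketch of the same pattern with a generic layer (not diffusers code):

```python
import torch
import torch.nn as nn
import torch.utils.checkpoint


def create_custom_forward(module):
    # checkpoint() expects a callable; wrapping the module lets positional args pass through
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# activations inside `layer` are recomputed during backward instead of being stored
out = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
out.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])
```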
@@ -609,11 +625,24 @@ class DownBlock2D(nn.Module):
         else:
             self.downsamplers = None
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, temb=None):
         output_states = ()
 
         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
             output_states += (hidden_states,)
 
         if self.downsamplers is not None:
@@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None
 
+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)
 
-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+    ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module):
         else:
             self.upsamplers = None
 
+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         for resnet in self.resnets:
             # pop res hidden states
@@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
 
         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module):
 class LDMBertPreTrainedModel(PreTrainedModel):
     config_class = LDMBertConfig
     base_model_prefix = "model"
-    supports_gradient_checkpointing = True
+    _supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]
 
     def _init_weights(self, module):
@@ -246,3 +246,21 @@ class ModelTesterMixin:
         outputs_tuple = model(**inputs_dict, return_dict=False)
 
         recursive_check(outputs_tuple, outputs_dict)
+
+    def test_enable_disable_gradient_checkpointing(self):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        # at init model should have gradient checkpointing disabled
+        model = self.model_class(**init_dict)
+        self.assertFalse(model.is_gradient_checkpointing)
+
+        # check enable works
+        model.enable_gradient_checkpointing()
+        self.assertTrue(model.is_gradient_checkpointing)
+
+        # check disable works
+        model.disable_gradient_checkpointing()
+        self.assertFalse(model.is_gradient_checkpointing)
@@ -18,7 +18,7 @@ import unittest
 
 import torch
 
-from diffusers import UNet2DModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device
 
 from .test_modeling_common import ModelTesterMixin
@@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))
 
 
+class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNet2DConditionModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 4
+        sizes = (32, 32)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+    @property
+    def input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": (32, 64),
+            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+            "cross_attention_dim": 32,
+            "attention_head_dim": 8,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 2,
+            "sample_size": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # now we save the output and parameter gradients that we will use for comparison purposes with
+        # the non-checkpointed run.
+        output_not_checkpointed = out.data.clone()
+        grad_not_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_not_checkpointed[name] = param.grad.data.clone()
+
+        model.enable_gradient_checkpointing()
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # now we save the output and parameter gradients that we will use for comparison purposes with
+        # the non-checkpointed run.
+        output_checkpointed = out.data.clone()
+        grad_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_checkpointed[name] = param.grad.data.clone()
+
+        # compare the output and parameters gradients
+        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
+        for name in grad_checkpointed:
+            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+
+
 # TODO(Patrick) - Re-add this test after having cleaned up LDM
 # def test_output_pretrained_spatial_transformer(self):
 #     model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
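For context (not part of the commit): checkpointing only takes effect while the model is in training mode, since each block checks `self.training and self.gradient_checkpointing`. A hedged sketch of typical fine-tuning usage; the config reuses the test dictionary above, and the target, loss, and optimizer are illustrative only:

```python
import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)
model.enable_gradient_checkpointing()  # lowers peak activation memory; backward gets slower
model.train()                          # the checkpointing branch is only taken in training mode

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

sample = torch.randn(2, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(2, 4, 32)
target = torch.randn(2, 4, 32, 32)  # dummy regression target, for illustration only

pred = model(sample, timestep, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()
optimizer.step()
optimizer.zero_grad()
```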