[UNet2DConditionModel] add gradient checkpointing (#461)
* add grad ckpt to downsample blocks
* make it work
* don't pass gradient_checkpointing to upsample block
* add tests for UNet2DConditionModel
* add test_gradient_checkpointing
* add gradient_checkpointing for up and down blocks
* add functions to enable and disable grad ckpt
* remove the forward argument
* better naming
* make supports_gradient_checkpointing private
parent 534512bedb
commit e7120bae95
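For orientation before the diff: a minimal usage sketch of the API this commit adds (not part of the commit itself; the constructor arguments and tensor shapes are borrowed from the new UNet2DConditionModel tests further down).

import torch
from diffusers import UNet2DConditionModel

model = UNet2DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    in_channels=4,
    out_channels=4,
    layers_per_block=2,
    sample_size=32,
)

model.enable_gradient_checkpointing()   # trade extra compute for lower activation memory
assert model.is_gradient_checkpointing  # the flag is now set on the down/up blocks

model.train()  # checkpointing is only used while the model is in training mode
sample = torch.randn(1, 4, 32, 32)
timestep = torch.tensor([10])
encoder_hidden_states = torch.randn(1, 4, 32)
out = model(sample=sample, timestep=timestep, encoder_hidden_states=encoder_hidden_states).sample
out.sum().backward()  # checkpointed blocks recompute their activations here

model.disable_gradient_checkpointing()  # restore the default behaviour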
@@ -15,6 +15,7 @@
 # limitations under the License.

 import os
+from functools import partial
 from typing import Callable, List, Optional, Tuple, Union

 import torch
@@ -121,10 +122,42 @@ class ModelMixin(torch.nn.Module):
     """

     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
+    _supports_gradient_checkpointing = False

     def __init__(self):
         super().__init__()

+    @property
+    def is_gradient_checkpointing(self) -> bool:
+        """
+        Whether gradient checkpointing is activated for this model or not.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
+
+    def enable_gradient_checkpointing(self):
+        """
+        Activates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if not self._supports_gradient_checkpointing:
+            raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
+        self.apply(partial(self._set_gradient_checkpointing, value=True))
+
+    def disable_gradient_checkpointing(self):
+        """
+        Deactivates gradient checkpointing for the current model.
+
+        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
+        activations".
+        """
+        if self._supports_gradient_checkpointing:
+            self.apply(partial(self._set_gradient_checkpointing, value=False))
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
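The three helpers above lean on torch.nn.Module.apply, which visits the module itself and every submodule recursively, with functools.partial used to bind the value argument. A standalone toy sketch of that mechanism (illustrative only, not code from this commit):

from functools import partial

import torch.nn as nn


class Block(nn.Module):
    # Stands in for CrossAttnDownBlock2D and friends: it just carries the flag.
    def __init__(self):
        super().__init__()
        self.gradient_checkpointing = False


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = nn.ModuleList([Block(), Block()])

    def _set_gradient_checkpointing(self, module, value=False):
        # Only modules that actually support the feature get the flag flipped.
        if isinstance(module, Block):
            module.gradient_checkpointing = value

    def enable_gradient_checkpointing(self):
        # apply() calls the bound partial once per (sub)module.
        self.apply(partial(self._set_gradient_checkpointing, value=True))


model = ToyModel()
model.enable_gradient_checkpointing()
print([b.gradient_checkpointing for b in model.blocks])  # [True, True]

With this pattern, is_gradient_checkpointing only has to scan self.modules() for any submodule whose gradient_checkpointing attribute is set.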
@@ -3,12 +3,21 @@ from typing import Optional, Tuple, Union

 import torch
 import torch.nn as nn
+import torch.utils.checkpoint

 from ..configuration_utils import ConfigMixin, register_to_config
 from ..modeling_utils import ModelMixin
 from ..utils import BaseOutput
 from .embeddings import TimestepEmbedding, Timesteps
-from .unet_blocks import UNetMidBlock2DCrossAttn, get_down_block, get_up_block
+from .unet_blocks import (
+    CrossAttnDownBlock2D,
+    CrossAttnUpBlock2D,
+    DownBlock2D,
+    UNetMidBlock2DCrossAttn,
+    UpBlock2D,
+    get_down_block,
+    get_up_block,
+)


 @dataclass
@@ -54,6 +63,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -188,6 +199,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
             if hasattr(block, "attentions") and block.attentions is not None:
                 block.set_attention_slice(slice_size)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if isinstance(module, (CrossAttnDownBlock2D, DownBlock2D, CrossAttnUpBlock2D, UpBlock2D)):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.FloatTensor,
@@ -234,7 +249,9 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin):
         for downsample_block in self.down_blocks:
             if hasattr(downsample_block, "attentions") and downsample_block.attentions is not None:
                 sample, res_samples = downsample_block(
-                    hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
+                    hidden_states=sample,
+                    temb=emb,
+                    encoder_hidden_states=encoder_hidden_states,
                 )
             else:
                 sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
@@ -527,6 +527,8 @@ class CrossAttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -546,8 +548,22 @@ class CrossAttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)
+
             output_states += (hidden_states,)

         if self.downsamplers is not None:
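The branch above is the core of the feature: torch.utils.checkpoint.checkpoint(fn, *args) runs fn without storing intermediate activations and reruns it during the backward pass, trading memory for an extra forward computation. The create_custom_forward closure exists because checkpoint expects a plain callable taking positional tensor arguments, which lets the same wrapper serve both the resnet and the attention module; the branch is only taken when self.training and self.gradient_checkpointing are both true, and the same pattern repeats in DownBlock2D, CrossAttnUpBlock2D and UpBlock2D below. A minimal, self-contained sketch of the pattern (illustrative, not part of the diff):

import torch
import torch.nn as nn
import torch.utils.checkpoint


def create_custom_forward(module):
    # checkpoint() wants a function, not a module; the closure forwards the call.
    def custom_forward(*inputs):
        return module(*inputs)

    return custom_forward


layer = nn.Linear(16, 16)
x = torch.randn(2, 16, requires_grad=True)

# The activations inside `layer` are not kept; they are recomputed on backward.
y = torch.utils.checkpoint.checkpoint(create_custom_forward(layer), x)
y.sum().backward()
print(x.grad.shape)  # torch.Size([2, 16])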
@@ -609,11 +625,24 @@ class DownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, temb=None):
         output_states = ()

         for resnet in self.resnets:
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+
             output_states += (hidden_states,)

         if self.downsamplers is not None:
@@ -1072,6 +1101,8 @@ class CrossAttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
+
     def set_attention_slice(self, slice_size):
         if slice_size is not None and self.attn_num_head_channels % slice_size != 0:
             raise ValueError(
@@ -1087,15 +1118,36 @@ class CrossAttnUpBlock2D(nn.Module):
         for attn in self.attentions:
             attn._set_attention_slice(slice_size)

-    def forward(self, hidden_states, res_hidden_states_tuple, temb=None, encoder_hidden_states=None):
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states,
+        res_hidden_states_tuple,
+        temb=None,
+        encoder_hidden_states=None,
+    ):
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, context=encoder_hidden_states)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(attn), hidden_states, encoder_hidden_states
+                )
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, context=encoder_hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -1150,6 +1202,8 @@ class UpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states, res_hidden_states_tuple, temb=None):
         for resnet in self.resnets:
             # pop res hidden states
@@ -1157,7 +1211,17 @@ class UpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+            else:
+                hidden_states = resnet(hidden_states, temb)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -478,7 +478,7 @@ class LDMBertEncoderLayer(nn.Module):
 class LDMBertPreTrainedModel(PreTrainedModel):
     config_class = LDMBertConfig
     base_model_prefix = "model"
-    supports_gradient_checkpointing = True
+    _supports_gradient_checkpointing = True
     _keys_to_ignore_on_load_unexpected = [r"encoder\.version", r"decoder\.version"]

     def _init_weights(self, module):
@@ -246,3 +246,21 @@ class ModelTesterMixin:
         outputs_tuple = model(**inputs_dict, return_dict=False)

         recursive_check(outputs_tuple, outputs_dict)
+
+    def test_enable_disable_gradient_checkpointing(self):
+        if not self.model_class._supports_gradient_checkpointing:
+            return  # Skip test if model does not support gradient checkpointing
+
+        init_dict, _ = self.prepare_init_args_and_inputs_for_common()
+
+        # at init model should have gradient checkpointing disabled
+        model = self.model_class(**init_dict)
+        self.assertFalse(model.is_gradient_checkpointing)
+
+        # check enable works
+        model.enable_gradient_checkpointing()
+        self.assertTrue(model.is_gradient_checkpointing)
+
+        # check disable works
+        model.disable_gradient_checkpointing()
+        self.assertFalse(model.is_gradient_checkpointing)
@@ -18,7 +18,7 @@ import unittest

 import torch

-from diffusers import UNet2DModel
+from diffusers import UNet2DConditionModel, UNet2DModel
 from diffusers.testing_utils import floats_tensor, slow, torch_device

 from .test_modeling_common import ModelTesterMixin
@@ -159,6 +159,82 @@ class UNetLDMModelTests(ModelTesterMixin, unittest.TestCase):
         self.assertTrue(torch.allclose(output_slice, expected_output_slice, rtol=1e-3))


+class UNet2DConditionModelTests(ModelTesterMixin, unittest.TestCase):
+    model_class = UNet2DConditionModel
+
+    @property
+    def dummy_input(self):
+        batch_size = 4
+        num_channels = 4
+        sizes = (32, 32)
+
+        noise = floats_tensor((batch_size, num_channels) + sizes).to(torch_device)
+        time_step = torch.tensor([10]).to(torch_device)
+        encoder_hidden_states = floats_tensor((batch_size, 4, 32)).to(torch_device)
+
+        return {"sample": noise, "timestep": time_step, "encoder_hidden_states": encoder_hidden_states}
+
+    @property
+    def input_shape(self):
+        return (4, 32, 32)
+
+    @property
+    def output_shape(self):
+        return (4, 32, 32)
+
+    def prepare_init_args_and_inputs_for_common(self):
+        init_dict = {
+            "block_out_channels": (32, 64),
+            "down_block_types": ("CrossAttnDownBlock2D", "DownBlock2D"),
+            "up_block_types": ("UpBlock2D", "CrossAttnUpBlock2D"),
+            "cross_attention_dim": 32,
+            "attention_head_dim": 8,
+            "out_channels": 4,
+            "in_channels": 4,
+            "layers_per_block": 2,
+            "sample_size": 32,
+        }
+        inputs_dict = self.dummy_input
+        return init_dict, inputs_dict
+
+    def test_gradient_checkpointing(self):
+        init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
+        model = self.model_class(**init_dict)
+        model.to(torch_device)
+
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # now we save the output and parameter gradients that we will use for comparison purposes with
+        # the non-checkpointed run.
+        output_not_checkpointed = out.data.clone()
+        grad_not_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_not_checkpointed[name] = param.grad.data.clone()
+
+        model.enable_gradient_checkpointing()
+        out = model(**inputs_dict).sample
+        # run the backwards pass on the model. For backwards pass, for simplicity purpose,
+        # we won't calculate the loss and rather backprop on out.sum()
+        model.zero_grad()
+        out.sum().backward()
+
+        # now we save the output and parameter gradients that we will use for comparison purposes with
+        # the non-checkpointed run.
+        output_checkpointed = out.data.clone()
+        grad_checkpointed = {}
+        for name, param in model.named_parameters():
+            grad_checkpointed[name] = param.grad.data.clone()
+
+        # compare the output and parameters gradients
+        self.assertTrue((output_checkpointed == output_not_checkpointed).all())
+        for name in grad_checkpointed:
+            self.assertTrue(torch.allclose(grad_checkpointed[name], grad_not_checkpointed[name], atol=5e-5))
+
+
 # TODO(Patrick) - Re-add this test after having cleaned up LDM
 # def test_output_pretrained_spatial_transformer(self):
 # model = UNetLDMModel.from_pretrained("fusing/unet-ldm-dummy-spatial")
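The new test above only asserts correctness: outputs must match exactly and gradients must agree to atol=5e-5 between the plain and the checkpointed run. The practical benefit is lower peak activation memory during training, paid for with one extra forward pass per checkpointed block. A rough way to observe the effect on a CUDA machine (a sketch assuming init_dict and inputs_dict are built as in UNet2DConditionModelTests above; actual numbers depend on the configuration):

import torch


def peak_training_memory_mib(model, inputs):
    torch.cuda.reset_peak_memory_stats()
    out = model(**inputs).sample
    out.sum().backward()
    model.zero_grad()
    return torch.cuda.max_memory_allocated() / 2**20


# model = UNet2DConditionModel(**init_dict).to("cuda").train()
# inputs = {k: v.to("cuda") for k, v in inputs_dict.items()}
# print("baseline     :", peak_training_memory_mib(model, inputs))
# model.enable_gradient_checkpointing()
# print("checkpointed :", peak_training_memory_mib(model, inputs))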