Enabling gradient checkpointing for VAE (#2536)
* updated black format * update black format * make style format * updated line endings * update code formatting * Update examples/research_projects/onnxruntime/text_to_image/train_text_to_image.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/vae.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * Update src/diffusers/models/vae.py Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> * added vae gradient checkpointing test * make style --------- Co-authored-by: Patrick von Platen <patrick.v.platen@gmail.com> Co-authored-by: Will Berman <wlbberman@gmail.com>
This commit is contained in:
parent
a16957159e
commit
116f70cbf8
|
@ -412,6 +412,7 @@ def main():
|
|||
|
||||
if args.gradient_checkpointing:
|
||||
unet.enable_gradient_checkpointing()
|
||||
vae.enable_gradient_checkpointing()
|
||||
|
||||
# Enable TF32 for faster training on Ampere GPUs,
|
||||
# cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
|
||||
|
|
|
@ -65,6 +65,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
|||
Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
|
||||
"""
|
||||
|
||||
_supports_gradient_checkpointing = True
|
||||
|
||||
@register_to_config
|
||||
def __init__(
|
||||
self,
|
||||
|
@ -121,6 +123,10 @@ class AutoencoderKL(ModelMixin, ConfigMixin):
|
|||
self.tile_latent_min_size = int(sample_size / (2 ** (len(self.block_out_channels) - 1)))
|
||||
self.tile_overlap_factor = 0.25
|
||||
|
||||
def _set_gradient_checkpointing(self, module, value=False):
|
||||
if isinstance(module, (Encoder, Decoder)):
|
||||
module.gradient_checkpointing = value
|
||||
|
||||
def enable_tiling(self, use_tiling: bool = True):
|
||||
r"""
|
||||
Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
|
||||
|
|
|
@ -24,9 +24,7 @@ from .attention_processor import ( # noqa: F401
|
|||
SlicedAttnProcessor,
|
||||
XFormersAttnProcessor,
|
||||
)
|
||||
from .attention_processor import ( # noqa: F401
|
||||
AttnProcessor as AttnProcessorRename,
|
||||
)
|
||||
from .attention_processor import AttnProcessor as AttnProcessorRename # noqa: F401
|
||||
|
||||
|
||||
deprecate(
|
||||
|
|
|
@ -50,7 +50,13 @@ class Encoder(nn.Module):
|
|||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = torch.nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
self.conv_in = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[0],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
self.mid_block = None
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
@ -96,10 +102,28 @@ class Encoder(nn.Module):
|
|||
conv_out_channels = 2 * out_channels if double_z else out_channels
|
||||
self.conv_out = nn.Conv2d(block_out_channels[-1], conv_out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, x):
|
||||
sample = x
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(down_block), sample)
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
|
||||
else:
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
sample = down_block(sample)
|
||||
|
@ -129,7 +153,13 @@ class Decoder(nn.Module):
|
|||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[-1], kernel_size=3, stride=1, padding=1)
|
||||
self.conv_in = nn.Conv2d(
|
||||
in_channels,
|
||||
block_out_channels[-1],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
)
|
||||
|
||||
self.mid_block = None
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
@ -176,10 +206,27 @@ class Decoder(nn.Module):
|
|||
self.conv_act = nn.SiLU()
|
||||
self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, 3, padding=1)
|
||||
|
||||
self.gradient_checkpointing = False
|
||||
|
||||
def forward(self, z):
|
||||
sample = z
|
||||
sample = self.conv_in(sample)
|
||||
|
||||
if self.training and self.gradient_checkpointing:
|
||||
|
||||
def create_custom_forward(module):
|
||||
def custom_forward(*inputs):
|
||||
return module(*inputs)
|
||||
|
||||
return custom_forward
|
||||
|
||||
# middle
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(self.mid_block), sample)
|
||||
|
||||
# up
|
||||
for up_block in self.up_blocks:
|
||||
sample = torch.utils.checkpoint.checkpoint(create_custom_forward(up_block), sample)
|
||||
else:
|
||||
# middle
|
||||
sample = self.mid_block(sample)
|
||||
|
||||
|
|
|
@ -68,6 +68,47 @@ class AutoencoderKLTests(ModelTesterMixin, unittest.TestCase):
|
|||
def test_training(self):
|
||||
pass
|
||||
|
||||
@unittest.skipIf(torch_device == "mps", "Gradient checkpointing skipped on MPS")
|
||||
def test_gradient_checkpointing(self):
|
||||
# enable deterministic behavior for gradient checkpointing
|
||||
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
|
||||
model = self.model_class(**init_dict)
|
||||
model.to(torch_device)
|
||||
|
||||
assert not model.is_gradient_checkpointing and model.training
|
||||
|
||||
out = model(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model.zero_grad()
|
||||
|
||||
labels = torch.randn_like(out)
|
||||
loss = (out - labels).mean()
|
||||
loss.backward()
|
||||
|
||||
# re-instantiate the model now enabling gradient checkpointing
|
||||
model_2 = self.model_class(**init_dict)
|
||||
# clone model
|
||||
model_2.load_state_dict(model.state_dict())
|
||||
model_2.to(torch_device)
|
||||
model_2.enable_gradient_checkpointing()
|
||||
|
||||
assert model_2.is_gradient_checkpointing and model_2.training
|
||||
|
||||
out_2 = model_2(**inputs_dict).sample
|
||||
# run the backwards pass on the model. For backwards pass, for simplicity purpose,
|
||||
# we won't calculate the loss and rather backprop on out.sum()
|
||||
model_2.zero_grad()
|
||||
loss_2 = (out_2 - labels).mean()
|
||||
loss_2.backward()
|
||||
|
||||
# compare the output and parameters gradients
|
||||
self.assertTrue((loss - loss_2).abs() < 1e-5)
|
||||
named_params = dict(model.named_parameters())
|
||||
named_params_2 = dict(model_2.named_parameters())
|
||||
for name, param in named_params.items():
|
||||
self.assertTrue(torch_all_close(param.grad.data, named_params_2[name].grad.data, atol=5e-5))
|
||||
|
||||
def test_from_pretrained_hub(self):
|
||||
model, loading_info = AutoencoderKL.from_pretrained("fusing/autoencoder-kl-dummy", output_loading_info=True)
|
||||
self.assertIsNotNone(model)
|
||||
|
|
Loading…
Reference in New Issue