diff --git a/ldm/modules/diffusionmodules/util.py b/ldm/modules/diffusionmodules/util.py index 6b5b9dc..3d25919 100644 --- a/ldm/modules/diffusionmodules/util.py +++ b/ldm/modules/diffusionmodules/util.py @@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag): explicitly take as arguments. :param flag: if False, disable gradient checkpointing. """ - if False: # disabled checkpointing to allow requires_grad = False for main model + if flag: # disabled checkpointing to allow requires_grad = False for main model args = tuple(inputs) + tuple(params) return CheckpointFunction.apply(func, len(inputs), *args) else: @@ -264,4 +264,4 @@ class HybridConditioner(nn.Module): def noise_like(shape, device, repeat=False): repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) noise = lambda: torch.randn(shape, device=device) - return repeat_noise() if repeat else noise() \ No newline at end of file + return repeat_noise() if repeat else noise()