Update util.py

This commit is contained in:
Xavier 2022-09-20 21:48:16 -07:00 committed by GitHub
parent bb8f4f2dc1
commit a22c61800d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 2 deletions

View File

@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag):
explicitly take as arguments. explicitly take as arguments.
:param flag: if False, disable gradient checkpointing. :param flag: if False, disable gradient checkpointing.
""" """
if False: # disabled checkpointing to allow requires_grad = False for main model if flag: # disabled checkpointing to allow requires_grad = False for main model
args = tuple(inputs) + tuple(params) args = tuple(inputs) + tuple(params)
return CheckpointFunction.apply(func, len(inputs), *args) return CheckpointFunction.apply(func, len(inputs), *args)
else: else:
@ -264,4 +264,4 @@ class HybridConditioner(nn.Module):
def noise_like(shape, device, repeat=False): def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device) noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise() return repeat_noise() if repeat else noise()