Update util.py
This commit is contained in:
parent
bb8f4f2dc1
commit
a22c61800d
|
@ -109,7 +109,7 @@ def checkpoint(func, inputs, params, flag):
|
||||||
explicitly take as arguments.
|
explicitly take as arguments.
|
||||||
:param flag: if False, disable gradient checkpointing.
|
:param flag: if False, disable gradient checkpointing.
|
||||||
"""
|
"""
|
||||||
if False: # disabled checkpointing to allow requires_grad = False for main model
|
if flag: # disabled checkpointing to allow requires_grad = False for main model
|
||||||
args = tuple(inputs) + tuple(params)
|
args = tuple(inputs) + tuple(params)
|
||||||
return CheckpointFunction.apply(func, len(inputs), *args)
|
return CheckpointFunction.apply(func, len(inputs), *args)
|
||||||
else:
|
else:
|
||||||
|
@ -264,4 +264,4 @@ class HybridConditioner(nn.Module):
|
||||||
def noise_like(shape, device, repeat=False):
|
def noise_like(shape, device, repeat=False):
|
||||||
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
|
||||||
noise = lambda: torch.randn(shape, device=device)
|
noise = lambda: torch.randn(shape, device=device)
|
||||||
return repeat_noise() if repeat else noise()
|
return repeat_noise() if repeat else noise()
|
||||||
|
|
Loading…
Reference in New Issue