apply clip_grad_norm to the unscaled gradients just before stepping the optimizers through the scaler
parent 562c434113
commit 1939cd52b7
@@ -87,10 +87,13 @@ class EveryDreamOptimizer():
     def step(self, loss, step, global_step):
         self.scaler.scale(loss).backward()
 
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
         if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
+            if self.clip_grad_norm is not None:
+                for optimizer in self.optimizers:
+                    self.scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
+
             for optimizer in self.optimizers:
                 self.scaler.step(optimizer)
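For reference, the change follows the usual pattern for combining gradient clipping with torch.cuda.amp: call GradScaler.unscale_(optimizer) so clip_grad_norm_ sees gradients in their true range, then let scaler.step(optimizer) apply (or skip, on inf/NaN) the update. Below is a minimal, self-contained sketch of that pattern, assuming a CUDA device; the toy model, optimizer, and max_norm value are illustrative placeholders, not code from this repository.

    import torch

    # Toy setup, assumed for illustration only.
    model = torch.nn.Linear(8, 1).cuda()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    scaler = torch.cuda.amp.GradScaler()
    max_norm = 1.0  # assumed clip threshold

    for _ in range(10):
        x = torch.randn(16, 8, device="cuda")
        y = torch.randn(16, 1, device="cuda")

        optimizer.zero_grad(set_to_none=True)
        with torch.autocast("cuda", dtype=torch.float16):
            loss = torch.nn.functional.mse_loss(model(x), y)

        scaler.scale(loss).backward()   # gradients are still scaled here
        scaler.unscale_(optimizer)      # bring gradients back to real units
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_norm)
        scaler.step(optimizer)          # skips the step if grads contain inf/NaN
        scaler.update()                 # adjust the scale factor for the next step

unscale_ may be called at most once per optimizer between scale(loss).backward() and step(optimizer); clipping without unscaling first would clip against the scaled gradient norm, which is what this commit corrects.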