apply clip_grad_norm to the unscaled gradients just before calling scaler.step

Damian Stewart 2023-05-21 20:10:18 +02:00 committed by Victor Hall
parent 562c434113
commit 1939cd52b7
1 changed file with 6 additions and 3 deletions


@@ -87,10 +87,13 @@ class EveryDreamOptimizer():
     def step(self, loss, step, global_step):
         self.scaler.scale(loss).backward()
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
         if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
+            if self.clip_grad_norm is not None:
+                for optimizer in self.optimizers:
+                    self.scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
             for optimizer in self.optimizers:
                 self.scaler.step(optimizer)
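
For context, this change follows the standard PyTorch AMP recipe: gradients produced by scaler.scale(loss).backward() are still multiplied by the loss scale, so they must be brought back to their true magnitude with scaler.unscale_(optimizer) before torch.nn.utils.clip_grad_norm_ is applied, and this should happen only on steps where the optimizer actually steps. The snippet below is a minimal, self-contained sketch of that pattern under assumed names (model, optimizer, grad_accum, max_grad_norm); it is not taken from the EveryDream code and assumes a CUDA device.

import torch

# Illustrative setup: a toy model, one optimizer, and a GradScaler for mixed precision.
model = torch.nn.Linear(8, 8).cuda()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler()
max_grad_norm = 1.0
grad_accum = 2

for step, batch in enumerate(torch.randn(8, 4, 8, device="cuda")):
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).float().pow(2).mean()
    scaler.scale(loss).backward()                 # gradients are still scaled here

    if (step + 1) % grad_accum == 0:              # only on real optimizer steps
        scaler.unscale_(optimizer)                # restore true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
        scaler.step(optimizer)                    # skipped internally if grads are inf/nan
        scaler.update()
        optimizer.zero_grad(set_to_none=True)

Note that scaler.step() already checks the unscaled gradients for infs/NaNs and skips the optimizer step when they are found, which is why clipping must sit between unscale_ and step rather than right after backward().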