diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py
index eb45737..a7cd1b2 100644
--- a/optimizer/optimizers.py
+++ b/optimizer/optimizers.py
@@ -87,10 +87,13 @@ class EveryDreamOptimizer():
 
     def step(self, loss, step, global_step):
         self.scaler.scale(loss).backward()
 
-        if self.clip_grad_norm is not None:
-            torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
-            torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
         if ((global_step + 1) % self.grad_accum == 0) or (step == self.epoch_len - 1):
+            if self.clip_grad_norm is not None:
+                for optimizer in self.optimizers:
+                    self.scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(parameters=self.unet_params, max_norm=self.clip_grad_norm)
+                torch.nn.utils.clip_grad_norm_(parameters=self.text_encoder_params, max_norm=self.clip_grad_norm)
+
             for optimizer in self.optimizers:
                 self.scaler.step(optimizer)
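
For context, this follows the standard PyTorch AMP recipe: gradients produced by `scaler.scale(loss).backward()` are still multiplied by the loss scale, so `torch.nn.utils.clip_grad_norm_` only sees true gradient magnitudes after `scaler.unscale_(optimizer)`, and clipping is only meaningful on the accumulation boundary where the optimizer actually steps (previously it clipped still-scaled, partially accumulated gradients on every micro-step). Below is a minimal, self-contained sketch of that ordering; the model, data, and hyperparameters are placeholders, not EveryDream's actual training loop.

```python
import torch

# Minimal sketch of the AMP + grad-accumulation + clipping ordering the patch adopts.
# The model, data, and hyperparameters here are made-up placeholders.
use_cuda = torch.cuda.is_available()
device = "cuda" if use_cuda else "cpu"

model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=use_cuda)
grad_accum = 4
clip_grad_norm = 1.0

for step in range(8):
    x, y = torch.randn(16, 8, device=device), torch.randn(16, 1, device=device)
    with torch.autocast(device_type=device, enabled=use_cuda):
        loss = torch.nn.functional.mse_loss(model(x), y) / grad_accum
    scaler.scale(loss).backward()          # grads are still multiplied by the loss scale

    if (step + 1) % grad_accum == 0:       # only act on the accumulation boundary
        scaler.unscale_(optimizer)         # restore true gradient magnitudes
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
        scaler.step(optimizer)             # skips the step if grads overflowed to inf/nan
        scaler.update()
        optimizer.zero_grad(set_to_none=True)
```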