diff --git a/optimizer/optimizers.py b/optimizer/optimizers.py index a740289..feda9a1 100644 --- a/optimizer/optimizers.py +++ b/optimizer/optimizers.py @@ -122,10 +122,10 @@ class EveryDreamOptimizer(): """ te_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_TE_STATE_FILENAME) unet_optimizer_state_path = os.path.join(ckpt_path, OPTIMIZER_UNET_STATE_FILENAME) - if os.path.exists(te_optimizer_state_path) and self.optimizer_unet is not None: - self._load_optimizer(self.optimizer_unet, te_optimizer_state_path) - if os.path.exists(unet_optimizer_state_path) and self.optimizer_te is not None: - self._load_optimizer(self.optimizer_te, unet_optimizer_state_path) + if os.path.exists(te_optimizer_state_path) and self.optimizer_te is not None: + self._load_optimizer(self.optimizer_te, te_optimizer_state_path) + if os.path.exists(unet_optimizer_state_path) and self.optimizer_unet is not None: + self._load_optimizer(self.optimizer_unet, unet_optimizer_state_path) def create_optimizers(self, args, text_encoder_params, unet_params): """