diff --git a/train.py b/train.py index 2e4fbed..59d7920 100644 --- a/train.py +++ b/train.py @@ -122,8 +122,8 @@ class EveryDreamTrainingState: self.tokenizer = tokenizer self.scheduler = scheduler self.vae = vae - self.unet_ema = unet_ema, - self.text_encoder = text_encoder_ema + self.unet_ema = unet_ema + self.text_encoder_ema = text_encoder_ema @torch.no_grad()