diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 64e4021..732a6b5 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -866,7 +866,7 @@ def main(): ema_unet.store(unet.parameters()) ema_unet.copy_to(unet.parameters()) pipeline = StableDiffusionPipeline( - text_encoder=text_encoder, + text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module, vae=vae, unet=unet.module, tokenizer=tokenizer,