diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 452f462..21f28e5 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -751,7 +751,7 @@ def main(): # move models to device vae = vae.to(device, dtype=weight_dtype) unet = unet.to(device, dtype=torch.float32) - text_encoder = text_encoder.to(device, dtype=weight_dtype) + text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32) unet = torch.nn.parallel.DistributedDataParallel( unet,