From 3cefb57fc68f12459cfa95d36f5ccde7525e3f55 Mon Sep 17 00:00:00 2001 From: cafeai <116491182+cafeai@users.noreply.github.com> Date: Sat, 3 Dec 2022 19:42:50 +0900 Subject: [PATCH] fp32 Update --- trainer/diffusers_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 452f462..21f28e5 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -751,7 +751,7 @@ def main(): # move models to device vae = vae.to(device, dtype=weight_dtype) unet = unet.to(device, dtype=torch.float32) - text_encoder = text_encoder.to(device, dtype=weight_dtype) + text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32) unet = torch.nn.parallel.DistributedDataParallel( unet,