diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index 1a97c01..30cedb1 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -683,6 +683,14 @@ def main():
     if args.use_xformers:
         unet.set_use_memory_efficient_attention_xformers(True)
 
+    # "The “safer” approach would be to move the model to the device first and create the optimizer afterwards."
+    weight_dtype = torch.float16 if args.fp16 else torch.float32
+
+    # move models to device
+    vae = vae.to(device, dtype=weight_dtype)
+    unet = unet.to(device, dtype=torch.float32)
+    text_encoder = text_encoder.to(device, dtype=weight_dtype)
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
@@ -735,13 +743,6 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    weight_dtype = torch.float16 if args.fp16 else torch.float32
-
-    # move models to device
-    vae = vae.to(device, dtype=weight_dtype)
-    unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=weight_dtype)
-
     #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
 
     # create ema
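
The reordering above follows the torch.optim documentation's advice to move a model to its target device before constructing an optimizer for it, since the optimizer captures the parameters as they exist at construction time. Below is a minimal sketch of that ordering, using an illustrative stand-in model and AdamW rather than the trainer's actual UNet and optimizer setup:

    import torch

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Illustrative stand-in; the trainer moves the diffusers VAE, UNet,
    # and text encoder instead.
    model = torch.nn.Linear(16, 16)

    # Move the model to the device first ...
    model = model.to(device)

    # ... then build the optimizer, so it tracks the device-resident
    # parameters rather than the pre-move CPU tensors.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)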