From f2cfe65d09adcd092f7ef034045dc2d53f48418b Mon Sep 17 00:00:00 2001 From: Carlos Chavez <85657083+chavinlo@users.noreply.github.com> Date: Sun, 20 Nov 2022 00:09:35 -0500 Subject: [PATCH] Move the movel to device BEFORE creating the optimizer MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit >It shouldn’t matter, as the optimizer should hold the references to the parameter (even after moving them). However, the “safer” approach would be to move the model to the device first and create the optimizer afterwards. https://discuss.pytorch.org/t/should-i-create-optimizer-after-sending-the-model-to-gpu/133418/2 https://discuss.pytorch.org/t/effect-of-calling-model-cuda-after-constructing-an-optimizer/15165 At least in my experience with hivemind, if you initialize the optimizer and move the model afterwards, it will throw errors about finding some data in CPU and other on GPU. This shouldn't affect performance or anything I believe. --- trainer/diffusers_trainer.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 1a97c01..30cedb1 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -683,6 +683,14 @@ def main(): if args.use_xformers: unet.set_use_memory_efficient_attention_xformers(True) + # "The “safer” approach would be to move the model to the device first and create the optimizer afterwards." + weight_dtype = torch.float16 if args.fp16 else torch.float32 + + # move models to device + vae = vae.to(device, dtype=weight_dtype) + unet = unet.to(device, dtype=torch.float32) + text_encoder = text_encoder.to(device, dtype=weight_dtype) + if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails. try: import bitsandbytes as bnb @@ -735,13 +743,6 @@ def main(): print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") exit(0) - weight_dtype = torch.float16 if args.fp16 else torch.float32 - - # move models to device - vae = vae.to(device, dtype=weight_dtype) - unet = unet.to(device, dtype=torch.float32) - text_encoder = text_encoder.to(device, dtype=weight_dtype) - #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True) # create ema