Move the model to device BEFORE creating the optimizer

> It shouldn’t matter, as the optimizer should hold the references to the parameters (even after moving them). However, the “safer” approach would be to move the model to the device first and create the optimizer afterwards.

https://discuss.pytorch.org/t/should-i-create-optimizer-after-sending-the-model-to-gpu/133418/2
https://discuss.pytorch.org/t/effect-of-calling-model-cuda-after-constructing-an-optimizer/15165

At least in my experience with hivemind, if you create the optimizer first and move the model afterwards, it throws errors about some tensors being on the CPU while others are on the GPU. I don't believe this change affects performance either way.
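For reference, the ordering the linked threads recommend (and that this commit adopts) looks roughly like the minimal sketch below; the model is just a stand-in, not the trainer's actual `unet`/`vae`/`text_encoder` setup:

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Placeholder model standing in for the trainer's unet/text_encoder/vae.
model = torch.nn.Linear(16, 4)

# 1. Move the model to the target device first...
model = model.to(device)

# 2. ...then build the optimizer, so the parameters it tracks (and any state
#    it allocates later, e.g. Adam's moment buffers) live on the same device.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
```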
Carlos Chavez 2022-11-20 00:09:35 -05:00 committed by GitHub
parent 1d1f4022d2
commit f2cfe65d09
1 changed file with 8 additions and 7 deletions

```diff
@@ -683,6 +683,14 @@ def main():
     if args.use_xformers:
         unet.set_use_memory_efficient_attention_xformers(True)
 
+    # "The “safer” approach would be to move the model to the device first and create the optimizer afterwards."
+    weight_dtype = torch.float16 if args.fp16 else torch.float32
+
+    # move models to device
+    vae = vae.to(device, dtype=weight_dtype)
+    unet = unet.to(device, dtype=torch.float32)
+    text_encoder = text_encoder.to(device, dtype=weight_dtype)
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
@@ -735,13 +743,6 @@
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    weight_dtype = torch.float16 if args.fp16 else torch.float32
-
-    # move models to device
-    vae = vae.to(device, dtype=weight_dtype)
-    unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=weight_dtype)
-
     #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
 
     # create ema
```