Move the movel to device BEFORE creating the optimizer
>It shouldn’t matter, as the optimizer should hold the references to the parameter (even after moving them). However, the “safer” approach would be to move the model to the device first and create the optimizer afterwards. https://discuss.pytorch.org/t/should-i-create-optimizer-after-sending-the-model-to-gpu/133418/2 https://discuss.pytorch.org/t/effect-of-calling-model-cuda-after-constructing-an-optimizer/15165 At least in my experience with hivemind, if you initialize the optimizer and move the model afterwards, it will throw errors about finding some data in CPU and other on GPU. This shouldn't affect performance or anything I believe.
This commit is contained in:
parent
1d1f4022d2
commit
f2cfe65d09
|
@ -683,6 +683,14 @@ def main():
|
|||
if args.use_xformers:
|
||||
unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# "The “safer” approach would be to move the model to the device first and create the optimizer afterwards."
|
||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||
|
||||
# move models to device
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
|
||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
@ -735,13 +743,6 @@ def main():
|
|||
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
||||
exit(0)
|
||||
|
||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||
|
||||
# move models to device
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
|
||||
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||
|
||||
# create ema
|
||||
|
|
Loading…
Reference in New Issue