Merge pull request #47 from chavinlo/patch-2
Move the model to device BEFORE creating the optimizer
This commit is contained in:
commit
511ee9e6d2
|
@ -683,6 +683,14 @@ def main():
|
|||
if args.use_xformers:
|
||||
unet.set_use_memory_efficient_attention_xformers(True)
|
||||
|
||||
# "The “safer” approach would be to move the model to the device first and create the optimizer afterwards."
|
||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||
|
||||
# move models to device
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
|
||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
|
@ -735,13 +743,6 @@ def main():
|
|||
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
||||
exit(0)
|
||||
|
||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||
|
||||
# move models to device
|
||||
vae = vae.to(device, dtype=weight_dtype)
|
||||
unet = unet.to(device, dtype=torch.float32)
|
||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||
|
||||
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||
|
||||
# create ema
|
||||
|
|
Loading…
Reference in New Issue