Merge pull request #47 from chavinlo/patch-2
Move the model to device BEFORE creating the optimizer
This commit is contained in:
commit
511ee9e6d2
|
@ -683,6 +683,14 @@ def main():
|
||||||
if args.use_xformers:
|
if args.use_xformers:
|
||||||
unet.set_use_memory_efficient_attention_xformers(True)
|
unet.set_use_memory_efficient_attention_xformers(True)
|
||||||
|
|
||||||
|
# "The “safer” approach would be to move the model to the device first and create the optimizer afterwards."
|
||||||
|
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
||||||
|
|
||||||
|
# move models to device
|
||||||
|
vae = vae.to(device, dtype=weight_dtype)
|
||||||
|
unet = unet.to(device, dtype=torch.float32)
|
||||||
|
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
||||||
|
|
||||||
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
|
||||||
try:
|
try:
|
||||||
import bitsandbytes as bnb
|
import bitsandbytes as bnb
|
||||||
|
@ -735,13 +743,6 @@ def main():
|
||||||
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
|
||||||
exit(0)
|
exit(0)
|
||||||
|
|
||||||
weight_dtype = torch.float16 if args.fp16 else torch.float32
|
|
||||||
|
|
||||||
# move models to device
|
|
||||||
vae = vae.to(device, dtype=weight_dtype)
|
|
||||||
unet = unet.to(device, dtype=torch.float32)
|
|
||||||
text_encoder = text_encoder.to(device, dtype=weight_dtype)
|
|
||||||
|
|
||||||
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
|
||||||
|
|
||||||
# create ema
|
# create ema
|
||||||
|
|
Loading…
Reference in New Issue