From 4572617ff961209275791e2726481936e0d6ed53 Mon Sep 17 00:00:00 2001 From: Anthony Mercurio Date: Wed, 30 Nov 2022 10:54:30 -0700 Subject: [PATCH] use ddp for everything --- trainer/diffusers_trainer.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 9345eaa..a2f9e3b 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -743,8 +743,7 @@ def main(): print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.") exit(0) - dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True) - unet = dist_unet.module + unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True) # create ema if args.use_ema: @@ -776,7 +775,7 @@ def main(): pipeline = StableDiffusionPipeline( text_encoder=text_encoder, vae=vae, - unet=unet, + unet=unet.module, tokenizer=tokenizer, scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), @@ -801,6 +800,7 @@ def main(): progress_bar.update(1) global_step += 1 continue + b_start = time.perf_counter() latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample() latents = latents * 0.18215 @@ -825,7 +825,7 @@ def main(): # Predict the noise residual and compute loss with torch.autocast('cuda', enabled=args.fp16): - noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample + noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -894,7 +894,7 @@ def main(): pipeline = StableDiffusionPipeline( text_encoder=text_encoder, 
vae=vae, - unet=unet, + unet=unet.module, tokenizer=tokenizer, scheduler=scheduler, safety_checker=None, # disable safety checker to save memory