From b0cec788beaaadf75e97f8f5eb5083025ce7ee56 Mon Sep 17 00:00:00 2001
From: Anthony Mercurio
Date: Tue, 29 Nov 2022 22:06:21 -0700
Subject: [PATCH] Get DDP to work

---
 trainer/diffusers_trainer.py | 11 +++++++----
 1 file changed, 7 insertions(+), 4 deletions(-)

diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index 0b4685b..9345eaa 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -743,7 +743,8 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
+    unet = dist_unet.module
 
     # create ema
     if args.use_ema:
@@ -835,9 +836,7 @@ def main():
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
-                # All-reduce loss, backprop, and update weights
-                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-                loss = loss / world_size
+                # backprop and update
                 scaler.scale(loss).backward()
                 torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
                 scaler.step(optimizer)
@@ -857,6 +856,10 @@ def main():
                 world_images_per_second = rank_images_per_second * world_size
                 samples_seen = global_step * args.batch_size * world_size
 
+                # get global loss for logging
+                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+                loss = loss / world_size
+
                 if rank == 0:
                     progress_bar.update(1)
                     global_step += 1
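
For reference, the change boils down to two things: the UNet is wrapped in torch.nn.parallel.DistributedDataParallel so that gradients are averaged across ranks during backward(), and the all-reduce of the loss is deferred until after the optimizer step, where it only computes a global average for logging instead of changing the value that is backpropagated. The sketch below illustrates that pattern in isolation; it is not the trainer's actual code. It assumes a torchrun-style launch with RANK, WORLD_SIZE, and LOCAL_RANK set and an NCCL backend, and the model, data, and hyperparameters are placeholders.

    # Minimal DDP sketch of the pattern above (placeholders, not the trainer's code).
    import os
    import torch
    import torch.distributed as dist

    def main():
        # torchrun sets these environment variables for each process.
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])
        local_rank = int(os.environ["LOCAL_RANK"])
        dist.init_process_group("nccl")
        torch.cuda.set_device(local_rank)

        model = torch.nn.Linear(16, 16).cuda(local_rank)  # placeholder for the UNet
        ddp_model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[local_rank], gradient_as_bucket_view=True
        )
        optimizer = torch.optim.AdamW(ddp_model.parameters(), lr=1e-4)
        scaler = torch.cuda.amp.GradScaler()

        for step in range(10):
            x = torch.randn(8, 16, device=local_rank)  # placeholder batch
            with torch.cuda.amp.autocast():
                loss = torch.nn.functional.mse_loss(ddp_model(x), x)

            # Backprop and update: DDP averages gradients across ranks during
            # backward(), so the loss itself is not all-reduced before the step.
            scaler.scale(loss).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(ddp_model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

            # All-reduce the detached loss only to log the global average.
            global_loss = loss.detach()
            dist.all_reduce(global_loss, op=dist.ReduceOp.SUM)
            global_loss = global_loss / world_size
            if rank == 0:
                print(f"step {step}: global loss {global_loss.item():.4f}")

        dist.destroy_process_group()

    if __name__ == "__main__":
        main()

A script like this would be launched with torchrun --nproc_per_node=<num_gpus> script.py; the patched trainer embeds roughly the same structure inside its own main().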