Get DDP to work

commit b0cec788be (parent 8decb0bc7d)
Author: Anthony Mercurio, 2022-11-29 22:06:21 -07:00 (committed by GitHub)
1 changed file with 7 additions and 4 deletions


@@ -743,7 +743,8 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+    dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
+    unet = dist_unet.module
 
     # create ema
     if args.use_ema:
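
For context, a minimal sketch of the wrap-and-unwrap pattern this hunk adopts. The process-group setup, the LOCAL_RANK handling, and the function name are assumptions for illustration, not code from this commit:

# Sketch only: one process per GPU, torchrun-style environment variables.
import os
import torch
import torch.distributed as dist

def setup_ddp(model):
    dist.init_process_group(backend="nccl")
    rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(rank)
    model = model.to(rank)
    # Wrap for synchronized training; keep a handle to .module for code
    # that needs the raw model (checkpointing, EMA, etc.).
    dist_model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[rank], gradient_as_bucket_view=True
    )
    return dist_model, dist_model.module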
@@ -835,9 +836,7 @@ def main():
            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
-            # All-reduce loss, backprop, and update weights
-            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-            loss = loss / world_size
+            # backprop and update
             scaler.scale(loss).backward()
             torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
             scaler.step(optimizer)
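
The remaining optimizer step is unchanged. With DistributedDataParallel, when the forward pass runs through the wrapper, gradient averaging across ranks happens inside backward(), so the loss itself no longer needs to be all-reduced before updating weights. A minimal sketch of that step, assuming the standard torch.cuda.amp GradScaler recipe (the unscale_ call before clipping is part of that general recipe and an assumption here, not necessarily this trainer's exact code):

# Sketch of one mixed-precision training step under DDP; names are illustrative.
import torch

def train_step(dist_model, loss, optimizer, scaler):
    scaler.scale(loss).backward()          # DDP averages grads across ranks here
    scaler.unscale_(optimizer)             # unscale grads before clipping
    torch.nn.utils.clip_grad_norm_(dist_model.parameters(), 1.0)
    scaler.step(optimizer)                 # skipped automatically on inf/nan grads
    scaler.update()
    optimizer.zero_grad(set_to_none=True)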
@@ -857,6 +856,10 @@ def main():
             world_images_per_second = rank_images_per_second * world_size
             samples_seen = global_step * args.batch_size * world_size
 
+            # get global loss for logging
+            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+            loss = loss / world_size
+
             if rank == 0:
                 progress_bar.update(1)
                 global_step += 1
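
The loss all-reduce now runs only on the logging path, so each rank averages the scalar loss across the world before rank 0 reports it. A minimal sketch of that pattern, assuming an already initialized process group; the helper name is illustrative:

# Sketch: average a scalar loss across ranks purely for logging.
import torch
import torch.distributed as dist

def global_mean_loss(loss: torch.Tensor) -> float:
    logged = loss.detach().clone()                  # keep the autograd graph untouched
    dist.all_reduce(logged, op=dist.ReduceOp.SUM)   # sum over all ranks
    logged /= dist.get_world_size()                 # then divide by world size
    return logged.item()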