Get DDP to work
parent 8decb0bc7d
commit b0cec788be
@@ -743,7 +743,8 @@ def main():
        print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
        exit(0)

    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
    dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
    unet = dist_unet.module

    # create ema
    if args.use_ema:
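Note on the hunk above: wrapping the model with DistributedDataParallel presupposes that the default process group is already initialized and that rank maps to the local CUDA device. A minimal sketch of that setup, not part of this commit (torchrun-style environment variables and the helper name setup_distributed are assumptions for illustration):

# Illustrative setup assumed by the DDP wrapping above; not part of this commit.
import os
import torch
import torch.distributed as dist

def setup_distributed():
    # torchrun exports RANK, WORLD_SIZE and LOCAL_RANK for every worker process.
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    local_rank = int(os.environ["LOCAL_RANK"])
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(local_rank)
    return rank, world_size, local_rank

Gradient synchronization in DDP only happens when forward/backward run through the wrapper itself, so a training step would normally call dist_unet(...) rather than the unwrapped module; keeping unet = dist_unet.module around is a common way to reach the raw module for EMA updates and checkpoint saving.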
@@ -835,9 +836,7 @@ def main():

            loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

            # All-reduce loss, backprop, and update weights
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
            loss = loss / world_size
            # backprop and update
            scaler.scale(loss).backward()
            torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
            scaler.step(optimizer)
|
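One detail worth noting next to the backprop lines above: when gradients come out of a torch.cuda.amp.GradScaler, PyTorch's documented pattern is to unscale them before clipping, so the clip threshold applies to the true gradient norms, and to call scaler.update() after the step. A sketch of that ordering, reusing the script's existing scaler, optimizer and unet names (illustrative, not part of this commit):

# Illustrative AMP step ordering per the PyTorch GradScaler docs; not part of this commit.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                               # convert grads back to true scale
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)   # clip the unscaled gradient norm
scaler.step(optimizer)                                   # skips the step if grads contain inf/nan
scaler.update()
optimizer.zero_grad(set_to_none=True)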
@@ -857,6 +856,10 @@ def main():
            world_images_per_second = rank_images_per_second * world_size
            samples_seen = global_step * args.batch_size * world_size

            # get global loss for logging
            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
            loss = loss / world_size

            if rank == 0:
                progress_bar.update(1)
                global_step += 1
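The all-reduce added above averages the per-rank loss purely for logging. The same average can be computed on a detached copy, so the logged value stays independent of the tensor that went through backward; recent PyTorch releases with the NCCL backend also expose ReduceOp.AVG, which folds the division by world_size into the collective. A sketch under those assumptions (illustrative, not part of this commit):

# Illustrative logging-only loss average; not part of this commit.
log_loss = loss.detach().clone()
torch.distributed.all_reduce(log_loss, op=torch.distributed.ReduceOp.SUM)
log_loss = log_loss / world_size   # equivalent to ReduceOp.AVG on newer builds
if rank == 0:
    progress_bar.set_postfix(loss=log_loss.item(), samples_seen=samples_seen)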