Get DDP to work
parent 8decb0bc7d
commit b0cec788be
@@ -743,7 +743,8 @@ def main():
 print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
 exit(0)
 
-#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+dist_unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], gradient_as_bucket_view=True)
+unet = dist_unet.module
 
 # create ema
 if args.use_ema:
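The new lines wrap the UNet in DistributedDataParallel and then rebind `unet` to `dist_unet.module`, the underlying unwrapped module. A minimal sketch of the process-group setup this wrapping assumes is below; the `setup_ddp` helper, the single-node assumption (so `rank` equals the local GPU index), and the use of torchrun's environment variables are assumptions, not part of this commit.

import os
import torch
import torch.distributed as dist

def setup_ddp():
    # assumption: launched via torchrun, which sets RANK/WORLD_SIZE/LOCAL_RANK
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    torch.cuda.set_device(local_rank)  # one process per GPU
    return dist.get_rank(), dist.get_world_size()

Note that gradient synchronization only happens when the forward pass runs through the DDP wrapper (`dist_unet`); `.module` exposes the raw module, which is useful for checkpointing and EMA without the wrapper.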
@@ -835,9 +836,7 @@ def main():
 
 loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
 
-# All-reduce loss, backprop, and update weights
-torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
-loss = loss / world_size
+# backprop and update
 scaler.scale(loss).backward()
 torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
 scaler.step(optimizer)
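This hunk drops the manual all-reduce of the loss before backward: DDP already averages gradients across ranks inside backward(), so reducing the loss first is redundant and costs an extra collective call per step. A hedged sketch of the resulting AMP step, assuming `loss`, `optimizer`, and `unet` exist as in the trainer and `scaler` is a torch.cuda.amp.GradScaler:

# gradients are averaged across ranks by DDP during backward()
scaler.scale(loss).backward()
# the AMP docs recommend unscaling before clipping so the threshold
# applies to true gradient magnitudes (an addition, not in this commit)
scaler.unscale_(optimizer)
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()  # part of the standard GradScaler loop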
@@ -857,6 +856,10 @@ def main():
 world_images_per_second = rank_images_per_second * world_size
 samples_seen = global_step * args.batch_size * world_size
 
+# get global loss for logging
+torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+loss = loss / world_size
+
 if rank == 0:
     progress_bar.update(1)
     global_step += 1
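The added lines average the per-rank loss across all processes so that rank 0 logs a global value; since this runs after backward(), it has no effect on gradients. A sketch of the same pattern as a reusable helper (the `global_mean` name is an assumption, not from the trainer):

import torch
import torch.distributed as dist

def global_mean(metric: torch.Tensor) -> torch.Tensor:
    # sum the scalar across all ranks, then divide by the
    # number of ranks to get the global average
    reduced = metric.detach().clone()
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced / dist.get_world_size()

On rank 0 this could feed the progress bar, e.g. `progress_bar.set_postfix(loss=global_mean(loss).item())` (a hypothetical usage, not in this commit).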