Synchronize ranks for DDP

Anthony Mercurio authored on 2022-11-30 19:05:02 -07:00, committed by GitHub
parent 4572617ff9
commit 7102d313ac
1 changed file with 43 additions and 28 deletions


@@ -78,7 +78,7 @@ parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default='True', he
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=bool_t, default='False', help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_steps', type=int, default=500, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use when logging images.')
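The --shuffle and --fp16 flags above pass type=bool_t, a helper defined elsewhere in the script and not shown in this hunk. A minimal sketch of what such a string-to-bool converter typically looks like (only the name bool_t comes from the code above; the body is an assumption):

    # Assumed helper, not part of this diff: converts argparse string values to booleans.
    def bool_t(value: str) -> bool:
        return str(value).lower() in ('true', '1', 't', 'yes', 'y')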
@@ -691,6 +691,13 @@ def main():
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=weight_dtype)
unet = torch.nn.parallel.DistributedDataParallel(
unet,
device_ids=[rank],
output_device=rank,
gradient_as_bucket_view=True
)
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
try:
import bitsandbytes as bnb
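This hunk moves the DistributedDataParallel wrap ahead of optimizer construction, so the optimizer built below sees the already-wrapped module's parameters. A minimal sketch of that ordering, with a placeholder module and a plain AdamW standing in for the trainer's UNet and ZeroRedundancyOptimizer:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group(backend='nccl')       # one process per GPU, launched via torchrun
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = torch.nn.Linear(8, 8).to(rank)        # placeholder module; the trainer wraps the UNet here
    model = DDP(model, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
    # Only now build the optimizer, so it is constructed over the DDP-wrapped parameters.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)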
@@ -701,6 +708,7 @@ def main():
else:
optimizer_cls = torch.optim.AdamW
"""
optimizer = optimizer_cls(
unet.parameters(),
lr=args.lr,
@@ -708,13 +716,25 @@ def main():
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
"""
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
clip_sample=False
)
# Create distributed optimizer
from torch.distributed.optim import ZeroRedundancyOptimizer
optimizer = ZeroRedundancyOptimizer(
unet.parameters(),
optimizer_class=optimizer_cls,
parameters_as_bucket_view=True,
lr=args.lr,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
noise_scheduler = DDPMScheduler.from_pretrained(
args.model,
subfolder='scheduler',
use_auth_token=args.hf_token,
)
# load dataset
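ZeroRedundancyOptimizer shards the optimizer state across ranks, so each process keeps only its own partition of the Adam moments while the hyperparameters are forwarded to the wrapped optimizer_class. One consequence when checkpointing is that the sharded state normally has to be gathered onto a single rank before it can be saved. A hedged sketch of that pattern, reusing the optimizer and rank from the code above (the file name is illustrative, not from this diff):

    # Gather the sharded optimizer state onto rank 0, then save it from that rank only.
    optimizer.consolidate_state_dict(to=0)
    if rank == 0:
        torch.save(optimizer.state_dict(), 'optimizer.pt')   # illustrative path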
@@ -743,8 +763,6 @@ def main():
print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
exit(0)
unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
# create ema
if args.use_ema:
ema_unet = EMAModel(unet.parameters())
@@ -786,8 +804,6 @@ def main():
if args.use_ema:
ema_unet.restore(unet.parameters())
# barrier
torch.distributed.barrier()
# train!
try:
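The torch.distributed.barrier() before the training loop blocks each rank until every process in the group has reached the same point, so all ranks enter training together; this is the synchronization the commit title refers to. The usual reason for such a barrier is to let one rank finish one-off setup before the others proceed. A minimal sketch, with prepare_dataset() as a purely hypothetical stand-in:

    import torch.distributed as dist

    if dist.get_rank() == 0:
        prepare_dataset()        # hypothetical rank-0-only setup work
    dist.barrier()               # the other ranks wait here until rank 0 is done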
@@ -823,10 +839,6 @@ def main():
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
if noise_scheduler.config.prediction_type == "epsilon":
target = noise
elif noise_scheduler.config.prediction_type == "v_prediction":
@@ -834,6 +846,11 @@ def main():
else:
raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
with unet.join():
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
# backprop and update
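Wrapping the forward pass in unet.join() uses DDP's Join context manager, which is designed to keep the collective communication inside DDP from hanging when ranks end up processing uneven numbers of inputs. A minimal, self-contained sketch of the canonical pattern, where the whole loop runs under the context manager and each rank deliberately sees a different number of batches:

    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    dist.init_process_group('nccl')               # launched via torchrun, one process per GPU
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    model = DDP(torch.nn.Linear(8, 8).to(rank), device_ids=[rank])
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    # Ranks have uneven workloads; join() keeps DDP's collectives from deadlocking
    # once the ranks with fewer batches run out of data.
    batches = [torch.randn(2, 8, device=rank) for _ in range(3 + rank)]
    with model.join():
        for x in batches:
            loss = model(x).pow(2).mean()         # toy loss
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()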
@@ -875,11 +892,10 @@ def main():
progress_bar.set_postfix(logs)
run.log(logs, step=global_step)
if global_step % args.save_steps == 0:
if global_step % args.save_steps == 0 and global_step > 0:
save_checkpoint(global_step)
if args.enableinference:
if global_step % args.image_log_steps == 0:
if global_step % args.image_log_steps == 0 and global_step > 0:
if rank == 0:
# get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
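The added and global_step > 0 guards keep the checkpoint and image-logging branches from firing on the very first step, where the modulo check alone would be trivially true:

    # Without the extra guard, step 0 satisfies global_step % N == 0 and triggers a save immediately.
    if global_step % args.save_steps == 0 and global_step > 0:
        save_checkpoint(global_step)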
@@ -935,7 +951,6 @@ def main():
# cleanup so we don't run out of memory
del pipeline
gc.collect()
torch.distributed.barrier()
except Exception as e:
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
pass