diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py
index a2f9e3b..781b3f1 100644
--- a/trainer/diffusers_trainer.py
+++ b/trainer/diffusers_trainer.py
@@ -78,7 +78,7 @@ parser.add_argument('--shuffle', dest='shuffle', type=bool_t, default='True', he
 parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
 parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
 parser.add_argument('--fp16', dest='fp16', type=bool_t, default='False', help='Train in mixed precision')
-parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
+parser.add_argument('--image_log_steps', type=int, default=500, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
 parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
@@ -690,6 +690,13 @@ def main():
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=torch.float32)
     text_encoder = text_encoder.to(device, dtype=weight_dtype)
+
+    unet = torch.nn.parallel.DistributedDataParallel(
+        unet,
+        device_ids=[rank],
+        output_device=rank,
+        gradient_as_bucket_view=True
+    )
 
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
@@ -701,6 +708,7 @@ def main():
     else:
         optimizer_cls = torch.optim.AdamW
 
+    """
     optimizer = optimizer_cls(
         unet.parameters(),
         lr=args.lr,
@@ -708,13 +716,25 @@ def main():
         eps=args.adam_epsilon,
         weight_decay=args.adam_weight_decay,
     )
+    """
 
-    noise_scheduler = DDPMScheduler(
-        beta_start=0.00085,
-        beta_end=0.012,
-        beta_schedule='scaled_linear',
-        num_train_timesteps=1000,
-        clip_sample=False
+    # Create distributed optimizer
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+    optimizer = ZeroRedundancyOptimizer(
+        unet.parameters(),
+        optimizer_class=optimizer_cls,
+        parameters_as_bucket_view=True,
+        lr=args.lr,
+        betas=(args.adam_beta1, args.adam_beta2),
+        eps=args.adam_epsilon,
+        weight_decay=args.adam_weight_decay,
+    )
+
+
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        args.model,
+        subfolder='scheduler',
+        use_auth_token=args.hf_token,
     )
 
     # load dataset
@@ -743,8 +763,6 @@ def main():
         print(f"Completed resize and migration to '{args.dataset}_cropped' please relaunch the trainer without the --resize argument and train on the migrated dataset.")
         exit(0)
 
-    unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
-
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
@@ -786,8 +804,6 @@ def main():
             if args.use_ema:
                 ema_unet.restore(unet.parameters())
 
-    # barrier
-    torch.distributed.barrier()
 
     # train!
     try:
@@ -823,26 +839,27 @@ def main():
                 else:
                     encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
-                # Predict the noise residual and compute loss
-                with torch.autocast('cuda', enabled=args.fp16):
-                    noise_pred = unet.module(noisy_latents, timesteps, encoder_hidden_states).sample
-
                 if noise_scheduler.config.prediction_type == "epsilon":
                     target = noise
                 elif noise_scheduler.config.prediction_type == "v_prediction":
                     target = noise_scheduler.get_velocity(latents, noise, timesteps)
                 else:
                     raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
 
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
-                # backprop and update
-                scaler.scale(loss).backward()
-                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
-                scaler.step(optimizer)
-                scaler.update()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+                with unet.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
 
                 # Update EMA
                 if args.use_ema:
@@ -875,11 +892,10 @@ def main():
                     progress_bar.set_postfix(logs)
                     run.log(logs, step=global_step)
 
-                if global_step % args.save_steps == 0:
+                if global_step % args.save_steps == 0 and global_step > 0:
                     save_checkpoint(global_step)
 
-                if args.enableinference:
-                    if global_step % args.image_log_steps == 0:
+                if global_step % args.image_log_steps == 0 and global_step > 0:
                         if rank == 0:
                             # get prompt from random batch
                             prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
@@ -935,7 +951,6 @@ def main():
                            # cleanup so we don't run out of memory
                            del pipeline
                            gc.collect()
-                torch.distributed.barrier()
     except Exception as e:
         print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
         pass
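For readers unfamiliar with the pattern the patch adopts, the sketch below shows the same three pieces in isolation: the model is wrapped in DistributedDataParallel before the optimizer is built, its parameters are handed to ZeroRedundancyOptimizer so the AdamW state is sharded across ranks instead of replicated, and the forward/backward/step runs inside the DDP join() context. The toy Linear model, gloo/CPU backend, world size, and learning rate here are illustrative assumptions, not the trainer's actual configuration.

# Minimal sketch of the DDP + ZeroRedundancyOptimizer pattern used by the patch.
# Toy model, gloo backend, and hyperparameters are assumptions for illustration only.
import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel


def worker(rank: int, world_size: int):
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    # Wrap in DDP first so the optimizer is constructed over the DDP-managed
    # parameters (mirrors moving the DDP wrap above optimizer creation in the diff).
    model = DistributedDataParallel(torch.nn.Linear(8, 8), gradient_as_bucket_view=True)

    # Each rank keeps only its shard of the AdamW state (exp_avg, exp_avg_sq),
    # which is what cuts per-GPU optimizer memory in the trainer.
    optimizer = ZeroRedundancyOptimizer(
        model.parameters(),
        optimizer_class=torch.optim.AdamW,
        parameters_as_bucket_view=True,
        lr=1e-4,
    )

    # join() is a safety net for ranks that run out of batches early; with equal
    # step counts per rank, as here, it changes nothing.
    with model.join():
        for _ in range(3):
            loss = model(torch.randn(4, 8)).pow(2).mean()
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(worker, args=(2,), nprocs=2)

Note that ZeroRedundancyOptimizer only shards optimizer state; parameters and gradients are still replicated on every rank, so the memory saved is roughly the two AdamW moment buffers for the UNet.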