diff --git a/trainer/diffusers_trainer.py b/trainer/diffusers_trainer.py index 551171c..0b4685b 100644 --- a/trainer/diffusers_trainer.py +++ b/trainer/diffusers_trainer.py @@ -777,9 +777,7 @@ def main(): vae=vae, unet=unet, tokenizer=tokenizer, - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ), + scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ) @@ -828,10 +826,20 @@ def main(): with torch.autocast('cuda', enabled=args.fp16): noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean") + if noise_scheduler.config.prediction_type == "epsilon": + target = noise + elif noise_scheduler.config.prediction_type == "v_prediction": + target = noise_scheduler.get_velocity(latents, noise, timesteps) + else: + raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}") + + loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean") - # Backprop and all reduce + # All-reduce loss, backprop, and update weights + torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) + loss = loss / world_size scaler.scale(loss).backward() + torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0) scaler.step(optimizer) scaler.update() lr_scheduler.step() @@ -849,14 +857,11 @@ def main(): world_images_per_second = rank_images_per_second * world_size samples_seen = global_step * args.batch_size * world_size - # All reduce loss - torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM) - if rank == 0: progress_bar.update(1) global_step += 1 logs = { - "train/loss": loss.detach().item() / world_size, + "train/loss": loss.detach().item(), "train/lr": lr_scheduler.get_last_lr()[0], "train/epoch": epoch, "train/step": global_step, @@ -878,14 +883,10 @@ def main(): if args.image_log_scheduler == 'DDIMScheduler': print('using DDIMScheduler scheduler') - scheduler = DDIMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" - ) + scheduler = DDIMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token) else: print('using PNDMScheduler scheduler') - scheduler=PNDMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True - ) + scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token) pipeline = StableDiffusionPipeline( text_encoder=text_encoder,