diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index 7711cf8..eb364f8 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -33,7 +33,7 @@ except pynvml.nvml.NVMLError_LibraryNotFound:
     pynvml = None
 
 from typing import Iterable
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -77,6 +77,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
 parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
+parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
+parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use when logging images (PNDMScheduler or DDIMScheduler).')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
 args = parser.parse_args()
@@ -586,7 +588,8 @@ def main():
 
     global_step = 0
     if args.resume:
-        global_step = int(args.resume.split('_')[-1])
+        target_global_step = int(args.resume.split('_')[-1])
+        print(f'resuming from {args.resume}...')
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
@@ -611,6 +614,7 @@ def main():
                 safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
             )
+            print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
             pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
         # barrier
         torch.distributed.barrier()
@@ -621,6 +625,11 @@ def main():
     for epoch in range(args.epochs):
         unet.train()
         for _, batch in enumerate(train_dataloader):
+            if args.resume and global_step < target_global_step:
+                if rank == 0:
+                    progress_bar.update(1)
+                global_step += 1
+                continue
             b_start = time.perf_counter()
             latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
             latents = latents * 0.18215
@@ -684,7 +693,7 @@ def main():
                     "perf/global_samples_per_second": world_images_per_second,
                 }
                 progress_bar.set_postfix(logs)
-                run.log(logs)
+                run.log(logs, step=global_step)
 
             if global_step % args.save_steps == 0:
                 save_checkpoint(global_step)
@@ -693,15 +702,26 @@ def main():
                 if rank == 0:
                     # get prompt from random batch
                     prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+
+                    if args.image_log_scheduler == 'DDIMScheduler':
+                        print('using DDIMScheduler')
+                        scheduler = DDIMScheduler(
+                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+                        )
+                    else:
+                        print('using PNDMScheduler')
+                        scheduler = PNDMScheduler(
+                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                        )
+
+                    scheduler.set_timesteps(num_inference_steps=args.image_log_inference_steps)
                     pipeline = StableDiffusionPipeline(
                         text_encoder=text_encoder,
                         vae=vae,
                         unet=unet,
                         tokenizer=tokenizer,
-                        scheduler=PNDMScheduler(
-                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                        ),
-                        safety_checker=None, # display safety checker to save memory
+                        scheduler=scheduler,
+                        safety_checker=None, # disable safety checker to save memory
                         feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                     ).to(device)
                     # inference
@@ -709,9 +729,14 @@ def main():
                     with torch.no_grad():
                         with torch.autocast('cuda', enabled=args.fp16):
                             for _ in range(args.image_log_amount):
-                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                                images.append(
+                                    wandb.Image(
+                                        pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0],
+                                        caption=prompt,
+                                    )
+                                )
                     # log images under single caption
-                    run.log({'images': images})
+                    run.log({'images': images}, step=global_step)
                     # cleanup so we don't run out of memory
                     del pipeline
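For reference, a minimal standalone sketch of the sampling-scheduler selection this diff wires in. The flag semantics (`--image_log_scheduler`, `--image_log_inference_steps`) and beta settings come from the diff; the helper name `build_log_scheduler`, the model id, and the prompt are illustrative placeholders, not part of the trainer:

```python
# Standalone sketch of the scheduler selection added above; not the trainer
# itself. build_log_scheduler is a hypothetical helper; model id and prompt
# are placeholders.
from diffusers import DDIMScheduler, PNDMScheduler, StableDiffusionPipeline

def build_log_scheduler(name: str):
    # Beta settings match the Stable Diffusion v1 schedule used in the diff;
    # skip_prk_steps=True makes PNDM skip its Runge-Kutta warm-up steps.
    if name == "DDIMScheduler":
        return DDIMScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        )
    return PNDMScheduler(
        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
        skip_prk_steps=True,
    )

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe.scheduler = build_log_scheduler("DDIMScheduler")
# The pipeline forwards num_inference_steps to scheduler.set_timesteps, so
# fewer steps trade preview quality for faster logging.
image = pipe("a watercolor fox", num_inference_steps=50).images[0]
image.save("sample.png")
```

DDIM tends to give usable previews at lower step counts than PNDM, which is presumably the motivation for making the logging scheduler configurable.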
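The resume path is a fast-forward rather than a seek: `target_global_step` is parsed from the checkpoint suffix, and batches are drawn and counted without any forward/backward work until the counter catches up. A toy, self-contained illustration of that pattern (the batch list and `train_step` are stand-ins):

```python
# Toy fast-forward-on-resume loop; everything here is a stand-in.
target_global_step = 5                 # e.g. int('myrun_5'.split('_')[-1])
global_step = 0
train_dataloader = [f"batch-{i}" for i in range(8)]

def train_step(batch):
    print(f"step {global_step}: training on {batch}")

for batch in train_dataloader:
    if global_step < target_global_step:
        global_step += 1               # count the batch, skip the work
        continue
    train_step(batch)
    global_step += 1
```

Note the replay only lines up with the original run if the dataloader's shuffling is seeded identically across restarts.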
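Passing `step=global_step` to both `run.log` calls ties the W&B x-axis to the trainer's own counter, so a resumed run continues where it left off instead of restarting wandb's internal step at zero. A minimal sketch (project name arbitrary; offline mode only so the snippet runs without a W&B account):

```python
import wandb

# Log against an explicit step; offline mode writes locally instead of syncing.
run = wandb.init(project="diffusers-demo", mode="offline")
for global_step in range(100, 103):    # pretend we resumed at step 100
    run.log({"train/loss": 1.0 / global_step}, step=global_step)
run.finish()
```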