diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index 8d20303..7711cf8 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -25,6 +25,7 @@ import itertools
 import numpy as np
 import json
 import re
+import traceback
 
 try:
     pynvml.nvmlInit()
@@ -47,6 +48,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # TODO: add custom VAE support. should be simple with diffusers
 parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
 parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
+parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
 parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
 parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
 parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
@@ -63,6 +65,8 @@ parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
 parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
 parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
 parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
+parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
+parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Fraction of the total number of training steps to use for learning rate warmup; e.g. 0.1 means 10 percent of all steps.')
 parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproduceability purposes.')
 parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
 parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
@@ -93,17 +97,6 @@ def get_world_size() -> int:
         return 1
     return torch.distributed.get_world_size()
 
-# Inform the user of host, and various versions -- useful for debugging isseus.
-print("RUN_NAME:", args.run_name)
-print("HOST:", socket.gethostname())
-print("CUDA:", torch.version.cuda)
-print("TORCH:", torch.__version__)
-print("TRANSFORMERS:", transformers.__version__)
-print("DIFFUSERS:", diffusers.__version__)
-print("MODEL:", args.model)
-print("FP16:", args.fp16)
-print("RESOLUTION:", args.resolution)
-
 def get_gpu_ram() -> str:
     """
     Returns memory usage statistics for the CPU, GPU, and Torch.
@@ -483,15 +476,24 @@ def main():
     world_size = get_world_size()
     torch.cuda.set_device(rank)
 
-    if args.hf_token is None:
-        args.hf_token = os.environ['HF_API_TOKEN']
-
     if rank == 0:
         os.makedirs(args.output_path, exist_ok=True)
+        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
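+        # NOTE: config=vars(args) now includes hf_token whenever it is passed as a
+        # flag; the warning below nudges users toward the HF_API_TOKEN env var instead.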
+ print("RUN_NAME:", args.run_name) + print("HOST:", socket.gethostname()) + print("CUDA:", torch.version.cuda) + print("TORCH:", torch.__version__) + print("TRANSFORMERS:", transformers.__version__) + print("DIFFUSERS:", diffusers.__version__) + print("MODEL:", args.model) + print("FP16:", args.fp16) + print("RESOLUTION:", args.resolution) + + if args.hf_token is None: + args.hf_token = os.environ['HF_API_TOKEN'] + print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.') device = torch.device('cuda') @@ -504,6 +506,9 @@ def main(): torch.manual_seed(args.seed) print('RANDOM SEED:', args.seed) + if args.resume: + args.model = args.resume + tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token) text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token) vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token) @@ -561,11 +566,6 @@ def main(): collate_fn=dataset.collate_fn ) - lr_scheduler = get_scheduler( - 'constant', - optimizer=optimizer - ) - weight_dtype = torch.float16 if args.fp16 else torch.float32 # move models to device @@ -585,7 +585,18 @@ def main(): progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False) global_step = 0 - def save_checkpoint(): + if args.resume: + global_step = int(args.resume.split('_')[-1]) + + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs), + num_training_steps=args.epochs * num_steps_per_epoch, + #last_epoch=(global_step // num_steps_per_epoch) - 1, + ) + + def save_checkpoint(global_step): if rank == 0: if args.use_ema: ema_unet.copy_to(unet.parameters()) @@ -600,114 +611,117 @@ def main(): safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), ) - pipeline.save_pretrained(args.output_path) + pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}') # barrier torch.distributed.barrier() # train! 
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
+        num_training_steps=args.epochs * num_steps_per_epoch,
+        #last_epoch=(global_step // num_steps_per_epoch) - 1,
+    )
+
+    def save_checkpoint(global_step):
         if rank == 0:
             if args.use_ema:
                 ema_unet.copy_to(unet.parameters())
@@ -600,114 +611,117 @@ def main():
                 safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
             )
-            pipeline.save_pretrained(args.output_path)
+            pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
         # barrier
         torch.distributed.barrier()
 
     # train!
-    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
-    for epoch in range(args.epochs):
-        unet.train()
-        train_loss = 0.0
-        for step, batch in enumerate(train_dataloader):
-            b_start = time.perf_counter()
-            latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
-            latents = latents * 0.18215
+    try:
+        loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
+        for epoch in range(args.epochs):
+            unet.train()
+            for _, batch in enumerate(train_dataloader):
+                b_start = time.perf_counter()
+                latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
+                latents = latents * 0.18215
 
-            # Sample noise
-            noise = torch.randn_like(latents)
-            bsz = latents.shape[0]
-            # Sample a random timestep for each image
-            timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
-            timesteps = timesteps.long()
+                # Sample noise
+                noise = torch.randn_like(latents)
+                bsz = latents.shape[0]
+                # Sample a random timestep for each image
+                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
+                timesteps = timesteps.long()
 
-            # Add noise to the latents according to the noise magnitude at each timestep
-            # (this is the forward diffusion process)
-            noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                # Add noise to the latents according to the noise magnitude at each timestep
+                # (this is the forward diffusion process)
+                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-            # Get the text embedding for conditioning
-            encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
-            if args.clip_penultimate:
-                encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
-            else:
-                encoder_hidden_states = encoder_hidden_states.last_hidden_state
+                # Get the text embedding for conditioning
+                encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
+                if args.clip_penultimate:
+                    encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
+                else:
+                    encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
-            # Predict the noise residual and compute loss
-            with torch.autocast('cuda', enabled=args.fp16):
-                noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                # Predict the noise residual and compute loss
+                with torch.autocast('cuda', enabled=args.fp16):
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-            loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+                loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
-            # Backprop and all reduce
-            scaler.scale(loss).backward()
-            scaler.step(optimizer)
-            scaler.update()
-            lr_scheduler.step()
-            optimizer.zero_grad()
+                # Backprop and all reduce
+                scaler.scale(loss).backward()
+                scaler.step(optimizer)
+                scaler.update()
+                lr_scheduler.step()
+                optimizer.zero_grad()
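+                # NOTE: the scheduler is stepped once per optimizer step, so the
+                # --lr_scheduler_warmup fraction is measured in steps, not epochs.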
 
-            # Update EMA
-            if args.use_ema:
-                ema_unet.step(unet.parameters())
+                # Update EMA
+                if args.use_ema:
+                    ema_unet.step(unet.parameters())
 
-            # perf
-            b_end = time.perf_counter()
-            seconds_per_step = b_end - b_start
-            steps_per_second = 1 / seconds_per_step
-            rank_images_per_second = args.batch_size * steps_per_second
-            world_images_per_second = rank_images_per_second * world_size
-            samples_seen = global_step * args.batch_size * world_size
+                # perf
+                b_end = time.perf_counter()
+                seconds_per_step = b_end - b_start
+                steps_per_second = 1 / seconds_per_step
+                rank_images_per_second = args.batch_size * steps_per_second
+                world_images_per_second = rank_images_per_second * world_size
+                samples_seen = global_step * args.batch_size * world_size
 
-            # All reduce loss
-            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+                # All reduce loss
+                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
 
-            if rank == 0:
-                progress_bar.update(1)
-                global_step += 1
-                logs = {
-                    "train/loss": loss.detach().item() / world_size,
-                    "train/lr": lr_scheduler.get_last_lr()[0],
-                    "train/epoch": epoch,
-                    "train/samples_seen": samples_seen,
-                    "perf/rank_samples_per_second": rank_images_per_second,
-                    "perf/global_samples_per_second": world_images_per_second,
-                }
-                progress_bar.set_postfix(logs)
-                run.log(logs)
-
-                if global_step % args.save_steps == 0:
-                    save_checkpoint()
-
-            if global_step % args.image_log_steps == 0:
                 if rank == 0:
-                    # get prompt from random batch
-                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
-                    pipeline = StableDiffusionPipeline(
-                        text_encoder=text_encoder,
-                        vae=vae,
-                        unet=unet,
-                        tokenizer=tokenizer,
-                        scheduler=PNDMScheduler(
-                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                        ),
-                        safety_checker=None, # display safety checker to save memory
-                        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
-                    ).to(device)
-                    # inference
-                    images = []
-                    with torch.no_grad():
-                        with torch.autocast('cuda', enabled=args.fp16):
-                            for _ in range(args.image_log_amount):
-                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
-                    # log images under single caption
-                    run.log({'images': images})
+                    progress_bar.update(1)
+                    global_step += 1
+                    logs = {
+                        "train/loss": loss.detach().item() / world_size,
+                        "train/lr": lr_scheduler.get_last_lr()[0],
+                        "train/epoch": epoch,
+                        "train/step": global_step,
+                        "train/samples_seen": samples_seen,
+                        "perf/rank_samples_per_second": rank_images_per_second,
+                        "perf/global_samples_per_second": world_images_per_second,
+                    }
+                    progress_bar.set_postfix(logs)
+                    run.log(logs)
 
-                    # cleanup so we don't run out of memory
-                    del pipeline
-                    gc.collect()
-        torch.distributed.barrier()
+                    if global_step % args.save_steps == 0:
+                        save_checkpoint(global_step)
 
-    if rank == 0:
-        save_checkpoint()
+                    if global_step % args.image_log_steps == 0:
+                        if rank == 0:
+                            # get prompt from random batch
+                            prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                            pipeline = StableDiffusionPipeline(
+                                text_encoder=text_encoder,
+                                vae=vae,
+                                unet=unet,
+                                tokenizer=tokenizer,
+                                scheduler=PNDMScheduler(
+                                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                                ),
+                                safety_checker=None, # disable safety checker to save memory
+                                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+                            ).to(device)
+                            # inference
+                            images = []
+                            with torch.no_grad():
+                                with torch.autocast('cuda', enabled=args.fp16):
+                                    for _ in range(args.image_log_amount):
+                                        images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                            # log images under single caption
+                            run.log({'images': images})
+
+                            # cleanup so we don't run out of memory
+                            del pipeline
+                            gc.collect()
+            torch.distributed.barrier()
+    except Exception as e:
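+        # Catch-all handler: log the error and fall through to the final
+        # checkpoint save below so an OOM or a bad batch does not lose the run.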
+        print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
+
+    save_checkpoint(global_step)
 
     torch.distributed.barrier()
     cleanup()