diff --git a/diffusers_trainer.py b/diffusers_trainer.py
index f1c99aa..b3f6a2d 100644
--- a/diffusers_trainer.py
+++ b/diffusers_trainer.py
@@ -1,3 +1,6 @@
+# Example Usage:
+# torchrun --nproc_per_node=2 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
+
 import argparse
 import socket
 import torch
@@ -13,6 +16,7 @@ import psutil
 import pynvml
 import wandb
 import gc
+import time
 import itertools
 import numpy as np

@@ -31,6 +35,8 @@ from PIL import Image
 from typing import Dict, List, Generator, Tuple
 from scipy.interpolate import interp1d

+torch.backends.cuda.matmul.allow_tf32 = True
+
 # defaults should be good for everyone
 # TODO: add custom VAE support. should be simple with diffusers
 parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
@@ -62,11 +68,21 @@ parser.add_argument('--image_log_steps', type=int, default=100, help='Number of
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
 args = parser.parse_args()

-os.makedirs(args.output_path, exist_ok=True)
+def setup():
+    torch.distributed.init_process_group("nccl", init_method="env://")

-# remove hf_token from args so sneaky people don't steal it from the wandb logs
-sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
-run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+def cleanup():
+    torch.distributed.destroy_process_group()
+
+def get_rank() -> int:
+    if not torch.distributed.is_initialized():
+        return 0
+    return torch.distributed.get_rank()
+
+def get_world_size() -> int:
+    if not torch.distributed.is_initialized():
+        return 1
+    return torch.distributed.get_world_size()

 # Inform the user of host, and various versions -- useful for debugging isseus.
 print("RUN_NAME:", args.run_name)
@@ -322,18 +338,25 @@ class AspectBucket:

         self.bucket_data[bucket].append(index)

+        del entry
+
         return True

 class AspectBucketSampler(torch.utils.data.Sampler):
-    def __init__(self, bucket: AspectBucket):
+    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
         super().__init__(None)
         self.bucket = bucket
+        self.num_replicas = num_replicas
+        self.rank = rank

     def __iter__(self):
-        yield from self.bucket.get_batch_iterator()
+        # subsample the bucket to only include the elements that are assigned to this rank
+        indices = self.bucket.get_batch_iterator()
+        indices = list(indices)[self.rank::self.num_replicas]
+        return iter(indices)

     def __len__(self):
-        return self.bucket.get_batch_count()
+        return self.bucket.get_batch_count() // self.num_replicas

 class AspectDataset(torch.utils.data.Dataset):
     def __init__(self, store: ImageStore, tokenizer: CLIPTokenizer, ucg: float = 0.1):
@@ -435,10 +458,21 @@ class EMAModel:
         ]

 def main():
-    # get device. TODO: support multi-gpu
-    device = 'cpu'
-    if torch.cuda.is_available():
-        device = 'cuda'
+    rank = get_rank()
+    world_size = get_world_size()
+    torch.cuda.set_device(rank)
+
+    if args.hf_token is None:
+        args.hf_token = os.environ['HF_API_TOKEN']
+
+    if rank == 0:
+        os.makedirs(args.output_path, exist_ok=True)
+
+        # remove hf_token from args so sneaky people don't steal it from the wandb logs
+        sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
+        run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+
+    device = torch.device('cuda')

     print("DEVICE:", device)

@@ -489,18 +523,17 @@ def main():

     # load dataset
-    store = ImageStore(args.dataset)
     dataset = AspectDataset(store, tokenizer)
     bucket = AspectBucket(store, 16, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
-    sampler = AspectBucketSampler(bucket)
+    sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)

     print(f'STORE_LEN: {len(store)}')

     train_dataloader = torch.utils.data.DataLoader(
         dataset,
         batch_sampler=sampler,
-        num_workers=4,
+        num_workers=0,
         collate_fn=dataset.collate_fn
     )

@@ -516,6 +549,8 @@
     unet = unet.to(device, dtype=torch.float32)
     text_encoder = text_encoder.to(device, dtype=weight_dtype)

+    #unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
+
     # create ema
     if args.use_ema:
         ema_unet = EMAModel(unet.parameters())
@@ -527,27 +562,31 @@
     global_step = 0

     def save_checkpoint():
-        if args.use_ema:
-            ema_unet.copy_to(unet.parameters())
-        pipeline = StableDiffusionPipeline(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=PNDMScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-            ),
-            safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
-        )
-        pipeline.save_pretrained(args.output_path)
-
+        if rank == 0:
+            if args.use_ema:
+                ema_unet.copy_to(unet.parameters())
+            pipeline = StableDiffusionPipeline(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=PNDMScheduler(
+                    beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                ),
+                safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
+                feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+            )
+            pipeline.save_pretrained(args.output_path)
+        # barrier
+        torch.distributed.barrier()
+

     # train!
+    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
     for epoch in range(args.epochs):
         unet.train()
         train_loss = 0.0
         for step, batch in enumerate(train_dataloader):
-            # Convert images to latent space
+            b_start = time.perf_counter()
             latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
             latents = latents * 0.18215
@@ -571,7 +610,7 @@

             loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

-            # Backprop
+            # Backprop and all reduce
             scaler.scale(loss).backward()
             scaler.step(optimizer)
             scaler.update()
@@ -582,61 +621,72 @@
             if args.use_ema:
                 ema_unet.step(unet.parameters())

-            progress_bar.update(1)
-            global_step += 1
-            logs = {
-                "loss": loss.detach().item(),
-                "lr": lr_scheduler.get_last_lr()[0],
-                "epoch": epoch
-            }
-            progress_bar.set_postfix(logs)
-            run.log(logs)
+            # perf
+            b_end = time.perf_counter()
+            seconds_per_step = b_end - b_start
+            steps_per_second = 1 / seconds_per_step
+            rank_images_per_second = args.batch_size * steps_per_second
+            world_images_per_second = rank_images_per_second * world_size
+            samples_seen = global_step * args.batch_size * world_size
+
+            # All reduce loss
+            torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
+
+            if rank == 0:
+                progress_bar.update(1)
+                global_step += 1
+                logs = {
+                    "train/loss": loss.detach().item() / world_size,
+                    "train/lr": lr_scheduler.get_last_lr()[0],
+                    "train/epoch": epoch,
+                    "train/samples_seen": samples_seen,
+                    "perf/rank_samples_per_second": rank_images_per_second,
+                    "perf/global_samples_per_second": world_images_per_second,
+                }
+                progress_bar.set_postfix(logs)
+                run.log(logs)

             if global_step % args.save_steps == 0:
                 save_checkpoint()

             if global_step % args.image_log_steps == 0:
-                # get prompt from random batch
-                prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
-                pipeline = StableDiffusionPipeline(
-                    text_encoder=text_encoder,
-                    vae=vae,
-                    unet=unet,
-                    tokenizer=tokenizer,
-                    scheduler=PNDMScheduler(
-                        beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                    ),
-                    safety_checker=None, # display safety checker to save memory
-                    feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
-                ).to(device)
-                # inference
-                images = []
-                with torch.no_grad():
-                    with torch.autocast('cuda', enabled=args.fp16):
-                        for _ in range(args.image_log_amount):
-                            images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
-                # log images under single caption
-                run.log({'images': images})
+                if rank == 0:
+                    # get prompt from random batch
+                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                    pipeline = StableDiffusionPipeline(
+                        text_encoder=text_encoder,
+                        vae=vae,
+                        unet=unet,
+                        tokenizer=tokenizer,
+                        scheduler=PNDMScheduler(
+                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                        ),
+                        safety_checker=None, # display safety checker to save memory
+                        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+                    ).to(device)
+                    # inference
+                    images = []
+                    with torch.no_grad():
+                        with torch.autocast('cuda', enabled=args.fp16):
+                            for _ in range(args.image_log_amount):
+                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                    # log images under single caption
+                    run.log({'images': images})

-                # cleanup so we don't run out of memory
-                del pipeline
-                gc.collect()
+                    # cleanup so we don't run out of memory
+                    del pipeline
+                    gc.collect()
+                torch.distributed.barrier()

-    save_checkpoint()
+    if rank == 0:
+        save_checkpoint()
+
+    torch.distributed.barrier()
+    cleanup()

     print(get_gpu_ram())
     print('Done!')

-
-if __name__ == '__main__':
+if __name__ == "__main__":
+    setup()
     main()
-
-"""
-import numpy as np
-# save a sample
-img = batch['pixel_values'][0].permute(1, 2, 0).cpu().numpy()
-img = ((img + 1.0) * 127.5).astype(np.uint8)
-img = Image.fromarray(img)
-img.save('sample.png')
-break
-"""
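
# ---------------------------------------------------------------------------
# Notes: minimal sketches of the distributed patterns used in the patch above.
# These are illustrative only, not part of the patch; helper names are
# hypothetical.
# ---------------------------------------------------------------------------

# setup() relies on torch.distributed's env:// rendezvous: torchrun exports
# RANK, LOCAL_RANK, WORLD_SIZE, MASTER_ADDR, and MASTER_PORT for each process
# it spawns, and init_process_group() reads them from the environment. A
# self-contained sketch; run with: torchrun --nproc_per_node=2 demo.py
import os
import torch
import torch.distributed as dist

def setup():
    # picks up RANK / WORLD_SIZE / MASTER_ADDR / MASTER_PORT set by torchrun
    dist.init_process_group("nccl", init_method="env://")

if __name__ == "__main__":
    setup()
    rank = dist.get_rank()
    # one GPU per process on a single node, as in the patch; a multi-node
    # setup would use LOCAL_RANK here instead of the global rank
    torch.cuda.set_device(rank)
    print(f"rank {rank} of {dist.get_world_size()}, local rank {os.environ['LOCAL_RANK']}")
    dist.destroy_process_group()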
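
# The AspectBucketSampler change shards whole batches across ranks by
# striding: rank r keeps batches r, r+num_replicas, r+2*num_replicas, ...
# A sketch with stand-in batches (the real code strides
# bucket.get_batch_iterator()); shard_batches is a hypothetical helper.
def shard_batches(batches, rank, num_replicas):
    # every rank materializes the same batch list, then keeps its stride
    return list(batches)[rank::num_replicas]

all_batches = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9]]
print(shard_batches(all_batches, rank=0, num_replicas=2))  # [[0, 1], [4, 5], [8, 9]]
print(shard_batches(all_batches, rank=1, num_replicas=2))  # [[2, 3], [6, 7]]

# This assumes every rank produces the identical batch order. Shards are
# uneven when the batch count is not divisible by num_replicas, which is why
# the patched __len__ floor-divides; keeping ranks on the same step count
# matters once collectives like the loss all_reduce enter the training loop.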
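
# The loss all_reduce in the training loop is a logging-only reduction: SUM
# the per-rank scalars, then divide by world_size (the DistributedDataParallel
# wrap is left commented out in the patch, so gradients are not synchronized
# by this). A sketch assuming an initialized process group:
import torch
import torch.distributed as dist

def global_mean_loss(loss: torch.Tensor) -> float:
    # sums in place across all ranks; every rank must make this call,
    # or the collective will hang
    dist.all_reduce(loss, op=dist.ReduceOp.SUM)
    return loss.item() / dist.get_world_size()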
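
# save_checkpoint() follows the rank-0-writes-then-barrier pattern: only
# rank 0 touches the filesystem, and the barrier keeps the other ranks from
# racing ahead while the save is still in flight. A generic sketch
# (rank_zero_then_barrier is a hypothetical helper):
import torch.distributed as dist

def rank_zero_then_barrier(fn):
    # fn would be e.g. lambda: pipeline.save_pretrained(args.output_path)
    if dist.get_rank() == 0:
        fn()
    dist.barrier()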