From 81bdbb5e2a60fdfd4ee3c25ee2619adcd60f96ca Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Henrik=20Forst=C3=A9n?=
Date: Mon, 10 Oct 2022 22:29:27 +0300
Subject: [PATCH] DreamBooth DeepSpeed support for under 8 GB VRAM training
 (#735)

* Support deepspeed

* Dreambooth DeepSpeed documentation

* Remove unnecessary casts, documentation

Due to recent commits, some casts to half precision are no longer
necessary. Mention that DeepSpeed's version of Adam is about 2x faster.

* Review comments
---
 examples/dreambooth/README.md           | 40 +++++++++++++++++++++++++
 examples/dreambooth/train_dreambooth.py | 22 +++++++++-----
 2 files changed, 55 insertions(+), 7 deletions(-)

diff --git a/examples/dreambooth/README.md b/examples/dreambooth/README.md
index 6770133a..9ff90ea8 100644
--- a/examples/dreambooth/README.md
+++ b/examples/dreambooth/README.md
@@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \
   --max_train_steps=800
 ```
 
+### Training on an 8 GB GPU:
+
+With [DeepSpeed](https://www.deepspeed.ai/) it is possible to offload some tensors
+from VRAM to either CPU or NVMe, allowing training with less VRAM.
+
+DeepSpeed needs to be enabled with `accelerate config`. During configuration,
+answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
+mixed precision, and offloading of both parameters and optimizer state to the CPU,
+it is possible to train with under 8 GB of VRAM, at the cost of requiring
+significantly more system RAM (about 25 GB). See the [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
+
+Changing the default Adam optimizer to DeepSpeed's optimized version of Adam,
+`deepspeed.ops.adam.DeepSpeedCPUAdam`, gives a substantial speedup, but enabling
+it requires a CUDA toolchain with the same version as the one used to build PyTorch.
+The 8-bit optimizer does not currently appear to be compatible with DeepSpeed.
+
+```bash
+export MODEL_NAME="CompVis/stable-diffusion-v1-4"
+export INSTANCE_DIR="path-to-instance-images"
+export CLASS_DIR="path-to-class-images"
+export OUTPUT_DIR="path-to-save-model"
+
+accelerate launch train_dreambooth.py \
+  --pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
+  --instance_data_dir=$INSTANCE_DIR \
+  --class_data_dir=$CLASS_DIR \
+  --output_dir=$OUTPUT_DIR \
+  --with_prior_preservation --prior_loss_weight=1.0 \
+  --instance_prompt="a photo of sks dog" \
+  --class_prompt="a photo of dog" \
+  --resolution=512 \
+  --train_batch_size=1 \
+  --gradient_accumulation_steps=1 --gradient_checkpointing \
+  --learning_rate=5e-6 \
+  --lr_scheduler="constant" \
+  --lr_warmup_steps=0 \
+  --num_class_images=200 \
+  --max_train_steps=800 \
+  --mixed_precision=fp16
+```
 
 ## Inference
 
diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py
index 56deb4a7..c3b875a5 100644
--- a/examples/dreambooth/train_dreambooth.py
+++ b/examples/dreambooth/train_dreambooth.py
@@ -471,9 +471,17 @@ def main():
         unet, optimizer, train_dataloader, lr_scheduler
     )
 
-    # Move text_encode and vae to gpu
-    text_encoder.to(accelerator.device)
-    vae.to(accelerator.device)
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    # Move text_encoder and vae to the GPU.
+    # For mixed precision training we cast the text_encoder and vae weights to half-precision,
+    # as these models are only used for inference; keeping the weights in full precision is not required.
+    text_encoder.to(accelerator.device, dtype=weight_dtype)
+    vae.to(accelerator.device, dtype=weight_dtype)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -509,11 +517,11 @@ def main():
             with accelerator.accumulate(unet):
                 # Convert images to latent space
                 with torch.no_grad():
-                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
                     latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
@@ -539,12 +547,12 @@ def main():
                     loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     # Compute prior loss
-                    prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
+                    prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
 
                     # Add the prior loss to the instance loss.
                     loss = loss + args.prior_loss_weight * prior_loss
                 else:
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
             accelerator.backward(loss)
             if accelerator.sync_gradients:
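
Note (illustrative, not part of the patch): the DeepSpeed settings that the README section above selects through `accelerate config` (ZeRO stage 2, fp16 mixed precision, parameters and optimizer state offloaded to CPU) can also be expressed with Accelerate's `DeepSpeedPlugin`. The sketch below assumes `accelerate` and `deepspeed` are installed; the script still has to be started with `accelerate launch` for the distributed environment to be initialized.

```python
# Sketch only: programmatic equivalent of the `accelerate config` answers described above.
from accelerate import Accelerator
from accelerate.utils import DeepSpeedPlugin

deepspeed_plugin = DeepSpeedPlugin(
    zero_stage=2,                    # ZeRO stage 2: shard optimizer state and gradients
    gradient_accumulation_steps=1,
    offload_optimizer_device="cpu",  # keep optimizer state in system RAM
    offload_param_device="cpu",      # offload parameters to system RAM as well
)

accelerator = Accelerator(mixed_precision="fp16", deepspeed_plugin=deepspeed_plugin)
```

In practice the `accelerate config` questionnaire remains the simpler route; the sketch only makes explicit which settings the questionnaire answers map to.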
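
The README also recommends swapping the default optimizer for `deepspeed.ops.adam.DeepSpeedCPUAdam`. A minimal sketch of that swap is shown below; the `build_optimizer` helper and its hyperparameter values are hypothetical, and importing `DeepSpeedCPUAdam` requires a CUDA toolchain matching the one PyTorch was built with.

```python
# Sketch only: choose between torch.optim.AdamW and DeepSpeed's fused CPU Adam,
# which runs the optimizer step on the CPU and pairs well with optimizer-state offloading.
import torch


def build_optimizer(params, lr=5e-6, use_deepspeed_adam=False):
    # Hypothetical helper; the training script's own arguments and defaults may differ.
    if use_deepspeed_adam:
        from deepspeed.ops.adam import DeepSpeedCPUAdam

        return DeepSpeedCPUAdam(params, lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)
    return torch.optim.AdamW(params, lr=lr, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-2)


# Stand-in usage with a toy module; in the training script this would be the UNet's parameters.
model = torch.nn.Linear(4, 4)
optimizer = build_optimizer(model.parameters(), lr=5e-6)
```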