DreamBooth DeepSpeed support for under 8 GB VRAM training (#735)
* Support deepspeed * Dreambooth DeepSpeed documentation * Remove unnecessary casts, documentation Due to recent commits some casts to half precision are not necessary anymore. Mention that DeepSpeed's version of Adam is about 2x faster. * Review comments
This commit is contained in:
parent
71ca10c6a4
commit
81bdbb5e2a
|
@ -119,6 +119,46 @@ accelerate launch train_dreambooth.py \
|
|||
--max_train_steps=800
|
||||
```
|
||||
|
||||
### Training on a 8 GB GPU:
|
||||
|
||||
By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
|
||||
tensors from VRAM to either CPU or NVME allowing to train with less VRAM.
|
||||
|
||||
DeepSpeed needs to be enabled with `accelerate config`. During configuration
|
||||
answer yes to "Do you want to use DeepSpeed?". With DeepSpeed stage 2, fp16
|
||||
mixed precision and offloading both parameters and optimizer state to cpu it's
|
||||
possible to train on under 8 GB VRAM with a drawback of requiring significantly
|
||||
more RAM (about 25 GB). See [documentation](https://huggingface.co/docs/accelerate/usage_guides/deepspeed) for more DeepSpeed configuration options.
|
||||
|
||||
Changing the default Adam optimizer to DeepSpeed's special version of Adam
|
||||
`deepspeed.ops.adam.DeepSpeedCPUAdam` gives a substantial speedup but enabling
|
||||
it requires CUDA toolchain with the same version as pytorch. 8-bit optimizer
|
||||
does not seem to be compatible with DeepSpeed at the moment.
|
||||
|
||||
```bash
|
||||
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
|
||||
export INSTANCE_DIR="path-to-instance-images"
|
||||
export CLASS_DIR="path-to-class-images"
|
||||
export OUTPUT_DIR="path-to-save-model"
|
||||
|
||||
accelerate launch train_dreambooth.py \
|
||||
--pretrained_model_name_or_path=$MODEL_NAME --use_auth_token \
|
||||
--instance_data_dir=$INSTANCE_DIR \
|
||||
--class_data_dir=$CLASS_DIR \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--with_prior_preservation --prior_loss_weight=1.0 \
|
||||
--instance_prompt="a photo of sks dog" \
|
||||
--class_prompt="a photo of dog" \
|
||||
--resolution=512 \
|
||||
--train_batch_size=1 \
|
||||
--gradient_accumulation_steps=1 --gradient_checkpointing \
|
||||
--learning_rate=5e-6 \
|
||||
--lr_scheduler="constant" \
|
||||
--lr_warmup_steps=0 \
|
||||
--num_class_images=200 \
|
||||
--max_train_steps=800 \
|
||||
--mixed_precision=fp16
|
||||
```
|
||||
|
||||
## Inference
|
||||
|
||||
|
|
|
@ -471,9 +471,17 @@ def main():
|
|||
unet, optimizer, train_dataloader, lr_scheduler
|
||||
)
|
||||
|
||||
# Move text_encode and vae to gpu
|
||||
text_encoder.to(accelerator.device)
|
||||
vae.to(accelerator.device)
|
||||
weight_dtype = torch.float32
|
||||
if args.mixed_precision == "fp16":
|
||||
weight_dtype = torch.float16
|
||||
elif args.mixed_precision == "bf16":
|
||||
weight_dtype = torch.bfloat16
|
||||
|
||||
# Move text_encode and vae to gpu.
|
||||
# For mixed precision training we cast the text_encoder and vae weights to half-precision
|
||||
# as these models are only used for inference, keeping weights in full precision is not required.
|
||||
text_encoder.to(accelerator.device, dtype=weight_dtype)
|
||||
vae.to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
# We need to recalculate our total training steps as the size of the training dataloader may have changed.
|
||||
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
|
||||
|
@ -509,11 +517,11 @@ def main():
|
|||
with accelerator.accumulate(unet):
|
||||
# Convert images to latent space
|
||||
with torch.no_grad():
|
||||
latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
|
||||
latents = vae.encode(batch["pixel_values"].to(dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
||||
# Sample noise that we'll add to the latents
|
||||
noise = torch.randn(latents.shape).to(latents.device)
|
||||
noise = torch.randn_like(latents)
|
||||
bsz = latents.shape[0]
|
||||
# Sample a random timestep for each image
|
||||
timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
|
||||
|
@ -539,12 +547,12 @@ def main():
|
|||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
|
||||
# Compute prior loss
|
||||
prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
|
||||
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
|
||||
|
||||
# Add the prior loss to the instance loss.
|
||||
loss = loss + args.prior_loss_weight * prior_loss
|
||||
else:
|
||||
loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
|
||||
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
|
||||
|
||||
accelerator.backward(loss)
|
||||
if accelerator.sync_gradients:
|
||||
|
|
Loading…
Reference in New Issue