Dreambooth: reduce VRAM usage (#2039)

* Dreambooth: use `optimizer.zero_grad(set_to_none=True)` to reduce VRAM usage

* Allow the user to control `optimizer.zero_grad(set_to_none=True)` with --set_grads_to_none

* Update Dreambooth readme

* Fix link in readme

* Fix header size in readme
This commit is contained in:
Gleb Akhmerov 2023-01-23 15:21:03 +04:00 committed by GitHub
parent 6fedb29f11
commit a66f2baeb7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 52 additions and 1 deletions

View File

@ -131,6 +131,42 @@ accelerate launch train_dreambooth.py \
--max_train_steps=800
```
### Training on a 12GB GPU:
It is possible to run dreambooth on a 12GB GPU by using the following optimizations:
- [gradient checkpointing and the 8-bit optimizer](#training-on-a-16gb-gpu)
- [xformers](#training-with-xformers)
- [setting grads to none](#set-grads-to-none)
```bash
export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export INSTANCE_DIR="path-to-instance-images"
export CLASS_DIR="path-to-class-images"
export OUTPUT_DIR="path-to-save-model"
accelerate launch train_dreambooth.py \
--pretrained_model_name_or_path=$MODEL_NAME \
--instance_data_dir=$INSTANCE_DIR \
--class_data_dir=$CLASS_DIR \
--output_dir=$OUTPUT_DIR \
--with_prior_preservation --prior_loss_weight=1.0 \
--instance_prompt="a photo of sks dog" \
--class_prompt="a photo of dog" \
--resolution=512 \
--train_batch_size=1 \
--gradient_accumulation_steps=1 --gradient_checkpointing \
--use_8bit_adam \
--enable_xformers_memory_efficient_attention \
--set_grads_to_none \
--learning_rate=2e-6 \
--lr_scheduler="constant" \
--lr_warmup_steps=0 \
--num_class_images=200 \
--max_train_steps=800
```
### Training on a 8 GB GPU:
By using [DeepSpeed](https://www.deepspeed.ai/) it's possible to offload some
@ -418,5 +454,11 @@ You can enable memory efficient attention by [installing xFormers](https://githu
You can also use Dreambooth to train the specialized in-painting model. See [the script in the research folder for details](https://github.com/huggingface/diffusers/tree/main/examples/research_projects/dreambooth_inpaint).
### Set grads to none
To save even more memory, pass the `--set_grads_to_none` argument to the script. This will set grads to None instead of zero. However, be aware that it changes certain behaviors, so if you start experiencing any problems, remove this argument.
More info: https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html
### Experimental results
You can refer to [this blog post](https://huggingface.co/blog/dreambooth) that discusses some of DreamBooth experiments in detail. Specifically, it recommends a set of DreamBooth-specific tips and tricks that we have found to work well for a variety of subjects.

View File

@ -312,6 +312,15 @@ def parse_args(input_args=None):
parser.add_argument(
"--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
)
parser.add_argument(
"--set_grads_to_none",
action="store_true",
help=(
"Save more memory by using setting grads to None instead of zero. Be aware, that this changes certain"
" behaviors, so disable this argument if it causes any problems. More info:"
" https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
),
)
if input_args is not None:
args = parser.parse_args(input_args)
@ -828,7 +837,7 @@ def main(args):
accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
optimizer.step()
lr_scheduler.step()
optimizer.zero_grad()
optimizer.zero_grad(set_to_none=args.set_grads_to_none)
# Checks if the accelerator has performed an optimization step behind the scenes
if accelerator.sync_gradients: