diffusers/examples/text_to_image
Yasyf Mohamedali 856331c61b
Support training SD V2 with Flax (#1783)
* Support training SD V2 with Flax

Mostly involves supporting a v_prediction scheduler.

The implementation in #1777 doesn't take into account a recent refactor of `scheduling_utils_flax`, so this should be used instead.

* Add to other top-level files.
2023-01-04 13:19:22 +01:00
README.md Make xformers optional even if it is available (#1753) 2022-12-27 19:47:50 +01:00
requirements.txt Update requirements.txt (#1623) 2022-12-09 08:35:27 +01:00
requirements_flax.txt Update requirements.txt (#1623) 2022-12-09 08:35:27 +01:00
train_text_to_image.py Fix --resume_from_checkpoint step in train_text_to_image.py (#1914) 2023-01-04 13:05:55 +01:00
train_text_to_image_flax.py Support training SD V2 with Flax (#1783) 2023-01-04 13:19:22 +01:00

README.md

Stable Diffusion text-to-image fine-tuning

The train_text_to_image.py script shows how to fine-tune a Stable Diffusion model on your own dataset.

Note:

This script is experimental. It fine-tunes the whole model, and the model often overfits and runs into issues like catastrophic forgetting. We recommend trying different hyperparameters to get the best result on your dataset.

Running locally with PyTorch

Installing the dependencies

Before running the scripts, make sure to install the library's training dependencies:

Important

To make sure you can successfully run the latest versions of the example scripts, we highly recommend installing from source and keeping the install up to date as we update the example scripts frequently and install some example-specific requirements. To do this, execute the following steps in a new virtual environment:

git clone https://github.com/huggingface/diffusers
cd diffusers
pip install .

Then cd into the example folder and run:

pip install -r requirements.txt

And initialize an 🤗Accelerate environment with:

accelerate config

Pokemon example

You need to accept the model license before downloading or using the weights. In this example we'll use model version v1-4, so you'll need to visit its card, read the license and tick the checkbox if you agree.

You have to be a registered user on the 🤗 Hugging Face Hub, and you'll also need an access token for the code to work. For more information on access tokens, please refer to this section of the documentation.

Run the following command to authenticate with your token:

huggingface-cli login

If you have already cloned the repo, then you won't need to go through these steps.


Hardware

With gradient_checkpointing and mixed_precision it should be possible to fine-tune the model on a single 24GB GPU. For a higher batch_size and faster training, it's better to use GPUs with more than 30GB of memory.

Note: Change the resolution to 768 if you are using the stable-diffusion-2 768x768 model.

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export dataset_name="lambdalabs/pokemon-blip-captions"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model" 
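
As a rough guide to what these flags imply, the effective batch size per optimizer update is the per-device batch size times the number of gradient accumulation steps times the number of processes. The sketch below works through that arithmetic for the command above. It assumes a single GPU, and it assumes that --max_train_steps counts optimizer updates rather than individual forward passes; the function name is illustrative and not part of the training script.

```python
def effective_batch_size(train_batch_size: int,
                         gradient_accumulation_steps: int,
                         num_processes: int = 1) -> int:
    """Number of samples that contribute to one optimizer update."""
    return train_batch_size * gradient_accumulation_steps * num_processes


# Values taken from the command above; num_processes=1 is an assumption (single GPU).
batch = effective_batch_size(train_batch_size=1, gradient_accumulation_steps=4)

# Assumes max_train_steps=15000 counts optimizer updates, not forward passes.
samples_seen = batch * 15000
print(batch, samples_seen)
```

Raising --train_batch_size on a larger GPU while lowering --gradient_accumulation_steps by the same factor keeps this product unchanged.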

To run on your own training files, prepare the dataset according to the format required by datasets. You can find the instructions for how to do that in this document. If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.

export MODEL_NAME="CompVis/stable-diffusion-v1-4"
export TRAIN_DIR="path_to_your_dataset"

accelerate launch --mixed_precision="fp16" train_text_to_image.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --use_ema \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --gradient_accumulation_steps=4 \
  --gradient_checkpointing \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --lr_scheduler="constant" --lr_warmup_steps=0 \
  --output_dir="sd-pokemon-model"
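
One way to lay out a local training folder is the datasets image-folder convention: the images plus a metadata.jsonl file whose "file_name" column points at each image. The sketch below writes such a file. It assumes the script reads captions from a column named "text"; the helper name, the image file name, and the caption are made up for illustration.

```python
import json
import tempfile
from pathlib import Path


def write_metadata(train_dir, captions: dict) -> Path:
    """Write metadata.jsonl mapping each image file name to its caption."""
    out = Path(train_dir) / "metadata.jsonl"
    with out.open("w", encoding="utf-8") as f:
        for file_name, text in sorted(captions.items()):
            # "file_name" ties the row to an image in the same folder;
            # "text" is assumed to be the caption column the script expects.
            f.write(json.dumps({"file_name": file_name, "text": text}) + "\n")
    return out


# Demo in a temporary folder; in practice train_dir would be $TRAIN_DIR.
with tempfile.TemporaryDirectory() as tmp:
    path = write_metadata(tmp, {"0001.png": "a drawing of a green creature"})
    demo_rows = [json.loads(line) for line in path.read_text(encoding="utf-8").splitlines()]
    print(demo_rows)
```

If your captions live in a differently named column, check the script's arguments for a way to point it at that column rather than renaming your data.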

Once the training is finished, the model will be saved in the output_dir specified in the command. In this example it's sd-pokemon-model. To load the fine-tuned model for inference, just pass that path to StableDiffusionPipeline:

import torch
from diffusers import StableDiffusionPipeline

model_path = "path_to_saved_model"
pipe = StableDiffusionPipeline.from_pretrained(model_path, torch_dtype=torch.float16)
pipe.to("cuda")

image = pipe(prompt="yoda").images[0]
image.save("yoda-pokemon.png")

Training with Flax/JAX

For faster training on TPUs and GPUs you can leverage the Flax training example. Follow the instructions above to get the model and dataset before running the script.

Note: The Flax example doesn't yet support features like gradient checkpointing or gradient accumulation, so using Flax for faster training requires cards with more than 30GB of memory.

Before running the scripts, make sure to install the library's training dependencies:

pip install -U -r requirements_flax.txt

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export dataset_name="lambdalabs/pokemon-blip-captions"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --dataset_name=$dataset_name \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-pokemon-model" 

To run on your own training files, prepare the dataset according to the format required by datasets. You can find the instructions for how to do that in this document. If you wish to use custom loading logic, you should modify the script; we have left pointers for that in the training script.

export MODEL_NAME="duongna/stable-diffusion-v1-4-flax"
export TRAIN_DIR="path_to_your_dataset"

python train_text_to_image_flax.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR \
  --resolution=512 --center_crop --random_flip \
  --train_batch_size=1 \
  --mixed_precision="fp16" \
  --max_train_steps=15000 \
  --learning_rate=1e-05 \
  --max_grad_norm=1 \
  --output_dir="sd-pokemon-model"

Training with xFormers

You can enable memory-efficient attention by installing xFormers and passing the --enable_xformers_memory_efficient_attention argument to the script. This is not available with the Flax/JAX implementation.
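
Because the flag only makes sense when xFormers is installed, a launcher script might check for the package before adding it. The following is a minimal sketch of that idea using only the standard library. The helper name xformers_flag and the base_args list are hypothetical and not part of the training script; only the flag itself comes from the text above.

```python
import importlib.util


def xformers_flag(find_spec=importlib.util.find_spec) -> list:
    """Return the xFormers flag in a list if the package is importable, else []."""
    # find_spec returns None when no top-level module of that name is found.
    if find_spec("xformers") is not None:
        return ["--enable_xformers_memory_efficient_attention"]
    return []


# Hypothetical argument list for the PyTorch script; the flag is appended only if usable.
base_args = ["train_text_to_image.py", "--output_dir=sd-pokemon-model"]
print(base_args + xformers_flag())
```

The find_spec parameter is injectable only so the check can be exercised without installing or uninstalling the package.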