diff --git a/examples/dreambooth/train_dreambooth.py b/examples/dreambooth/train_dreambooth.py index d3720f0c..9b7e1724 100644 --- a/examples/dreambooth/train_dreambooth.py +++ b/examples/dreambooth/train_dreambooth.py @@ -574,6 +574,8 @@ def main(args): for epoch in range(args.num_train_epochs): unet.train() + if args.train_text_encoder: + text_encoder.train() for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space