[DreamBooth] Set train mode for text encoder (#1012)

Set train mode for text encoder
This commit is contained in:
Duong A. Nguyen 2022-10-27 19:19:13 +07:00 committed by GitHub
parent abe058221c
commit 4623f095f3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 2 additions and 0 deletions

View File

@ -574,6 +574,8 @@ def main(args):
for epoch in range(args.num_train_epochs):
unet.train()
if args.train_text_encoder:
text_encoder.train()
for step, batch in enumerate(train_dataloader):
with accelerator.accumulate(unet):
# Convert images to latent space