Implement Text Encoder Training

cafeai 2022-12-03 12:47:40 +09:00
parent c709257bec
commit 18ff256be5
1 changed file with 52 additions and 17 deletions


@@ -92,6 +92,8 @@ parser.add_argument('--extended_validation', type=bool_t, default='False', help=
 parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
 parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
 parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
+parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
 args = parser.parse_args()
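For context, the pre-existing flags above go through a custom bool_t string-to-bool converter defined earlier in the script, while the new --train_text_encoder flag uses argparse's built-in store_true action, so it defaults to False and becomes True simply by being passed. A minimal sketch of the difference; the bool_t body here is an assumption for illustration only, not the script's actual helper:

import argparse

def bool_t(x: str) -> bool:
    # Hypothetical stand-in for the script's converter: the existing flags pass
    # default='False' as a string, so the type callable maps strings to bools.
    return x.strip().lower() in ('true', '1', 'yes')

parser = argparse.ArgumentParser()
parser.add_argument('--skip_validation', type=bool_t, default='False')
parser.add_argument('--train_text_encoder', action='store_true', help='Whether to train the text encoder')

args = parser.parse_args(['--train_text_encoder'])
print(args.skip_validation, args.train_text_encoder)  # False True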
@@ -731,11 +733,15 @@ def main():
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
     if args.use_xformers:
         unet.set_use_memory_efficient_attention_xformers(True)
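As a rough aside, requires_grad_(False) freezes the CLIP text encoder whenever it is not being trained, and gradient_checkpointing_enable() is the transformers-side counterpart of the UNet's enable_gradient_checkpointing(), trading extra forward recomputation for lower activation memory. A self-contained sketch of the two toggles on a tiny randomly initialised CLIPTextModel (the config sizes are arbitrary, chosen only so the example runs without downloading weights):

from transformers import CLIPTextConfig, CLIPTextModel

# Tiny random text encoder so the example needs no pretrained checkpoint.
config = CLIPTextConfig(hidden_size=64, intermediate_size=128,
                        num_hidden_layers=2, num_attention_heads=2)
text_encoder = CLIPTextModel(config)

train_text_encoder = False     # stand-in for args.train_text_encoder
gradient_checkpointing = True  # stand-in for args.gradient_checkpointing

if not train_text_encoder:
    text_encoder.requires_grad_(False)            # frozen: no grads, no optimizer state
if gradient_checkpointing and train_text_encoder:
    text_encoder.gradient_checkpointing_enable()  # only matters when it actually gets gradients

print(any(p.requires_grad for p in text_encoder.parameters()))  # False while frozen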
@@ -753,7 +759,15 @@ def main():
         output_device=rank,
         gradient_as_bucket_view=True
     )
+    if args.train_text_encoder:
+        text_encoder = torch.nn.parallel.DistributedDataParallel(
+            text_encoder,
+            device_ids=[rank],
+            output_device=rank,
+            gradient_as_bucket_view=True
+        )
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
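DistributedDataParallel only synchronises gradients for the module it wraps, so once the text encoder is trainable it needs its own wrapper next to the UNet's. A minimal sketch of that wiring, using a single-process gloo group and Linear stand-ins for the two models (the real script runs one NCCL rank per GPU, hence device_ids=[rank]):

import os
import torch
import torch.distributed as dist

os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)  # toy single-process group

unet = torch.nn.Linear(4, 4)          # stand-in for the UNet
text_encoder = torch.nn.Linear(4, 4)  # stand-in for the CLIP text encoder

# Each trainable module gets its own DDP wrapper; a frozen text encoder would stay unwrapped.
unet = torch.nn.parallel.DistributedDataParallel(unet, gradient_as_bucket_view=True)
text_encoder = torch.nn.parallel.DistributedDataParallel(text_encoder, gradient_as_bucket_view=True)

dist.destroy_process_group()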
@@ -774,10 +788,12 @@ def main():
     )
     """
+    optimizer_parameters = unet.parameters() if not args.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters())
     # Create distributed optimizer
     from torch.distributed.optim import ZeroRedundancyOptimizer
     optimizer = ZeroRedundancyOptimizer(
-        unet.parameters(),
+        optimizer_parameters,
         optimizer_class=optimizer_cls,
         parameters_as_bucket_view=True,
         lr=args.lr,
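itertools.chain merges the two parameter iterators into one stream, so a single optimizer steps the UNet and, when enabled, the text encoder together; the added line assumes itertools is already imported near the top of the script. chain() is a one-shot iterator, which is fine here because the optimizer constructor consumes it exactly once. A small sketch with plain AdamW standing in for the script's ZeroRedundancyOptimizer, which would additionally need an initialised process group:

import itertools
import torch

unet = torch.nn.Linear(8, 8)          # stand-ins for the real models
text_encoder = torch.nn.Linear(8, 8)
train_text_encoder = True

# One flat iterable over both parameter sets when the text encoder is trained.
optimizer_parameters = unet.parameters() if not train_text_encoder else itertools.chain(
    unet.parameters(), text_encoder.parameters())

optimizer = torch.optim.AdamW(optimizer_parameters, lr=5e-6)
print(sum(len(g["params"]) for g in optimizer.param_groups))  # 4 tensors: both weight/bias pairs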
@@ -866,6 +882,8 @@ def main():
     loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
     for epoch in range(args.epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for _, batch in enumerate(train_dataloader):
             if args.resume and global_step < target_global_step:
                 if rank == 0:
@@ -898,20 +916,37 @@ def main():
             else:
                 raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
-            with unet.join():
-                # Predict the noise residual and compute loss
-                with torch.autocast('cuda', enabled=args.fp16):
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+            if not args.train_text_encoder:
+                with unet.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")

-                # backprop and update
-                scaler.scale(loss).backward()
-                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
-                scaler.step(optimizer)
-                scaler.update()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
+            else:
+                with unet.join(), text_encoder.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
             # Update EMA
             if args.use_ema:
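One detail of the joint branch: the DDP-wrapped text encoder enters its own join() context alongside the UNet's, and its gradients are clipped by a separate clip_grad_norm_ call, so each model's gradient norm is capped at 1.0 independently rather than the two being clipped as one combined vector. A toy sketch of that distinction on Linear stand-ins (not from the commit):

import torch

unet = torch.nn.Linear(4, 4)
text_encoder = torch.nn.Linear(4, 4)

loss = (unet(torch.randn(2, 4)) + text_encoder(torch.randn(2, 4))).pow(2).mean()
loss.backward()

# Per-model clipping, as in the diff: each norm is capped at 1.0 on its own.
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)

# A single call over the chained parameters (needs import itertools) would instead
# cap the joint norm, which behaves differently whenever both models carry sizeable gradients:
# torch.nn.utils.clip_grad_norm_(itertools.chain(unet.parameters(), text_encoder.parameters()), 1.0)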