Merge pull request #58 from harubaru/text-encoder-updates

FEAT: Text Encoder (CLIP) Training
Commit: 27d301c5b9
Author: Anthony Mercurio (committed by GitHub)
Date: 2022-12-03 09:24:43 -07:00
1 changed file with 57 additions and 20 deletions


@@ -92,6 +92,8 @@ parser.add_argument('--extended_validation', type=bool_t, default='False', help=
 parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
 parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
 parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
+parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
 args = parser.parse_args()
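A note on the new flag: unlike the surrounding options, which route through the custom `bool_t` parser with a string default, `--train_text_encoder` uses argparse's built-in `store_true` action, so it defaults to `False` and is switched on by bare presence. A minimal standalone sketch of that behavior (not the trainer's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
# store_true: defaults to False, becomes True when the bare flag is passed
parser.add_argument("--train_text_encoder", action="store_true",
                    help="Whether to train the text encoder")

assert parser.parse_args([]).train_text_encoder is False
assert parser.parse_args(["--train_text_encoder"]).train_text_encoder is True
```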
@@ -498,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset):
         self.device = device
         self.ucg = ucg
+        if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
+            self.text_encoder = self.text_encoder.module
         self.transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor(),
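The two added lines unwrap a `DistributedDataParallel`-wrapped text encoder back to the raw module, since the dataset only needs it for encoding, not for gradient synchronization. The same pattern, written as a sketch with the more idiomatic `isinstance` check:

```python
import torch

def unwrap_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # DDP exposes the wrapped model via .module; pass plain modules through.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        return model.module
    return model
```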
@@ -731,10 +736,13 @@ def main():
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
     if args.use_xformers:
         unet.set_use_memory_efficient_attention_xformers(True)
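Two details in this hunk are easy to miss: the VAE stays frozen unconditionally, and gradient checkpointing goes through two differently named APIs, the diffusers-style `enable_gradient_checkpointing()` on the UNet versus the transformers-style `gradient_checkpointing_enable()` on the CLIP text encoder. The freeze itself is just `requires_grad_`, which flips the flag recursively; a runnable sketch with a toy stand-in for the text encoder:

```python
import torch

encoder = torch.nn.Linear(8, 8)    # stand-in for the CLIP text encoder
train_text_encoder = False         # stand-in for args.train_text_encoder

if not train_text_encoder:
    encoder.requires_grad_(False)  # flips requires_grad on every parameter

assert all(not p.requires_grad for p in encoder.parameters())
```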
@@ -745,7 +753,7 @@ def main():
     # move models to device
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=weight_dtype)
+    text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)
     unet = torch.nn.parallel.DistributedDataParallel(
         unet,
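The dtype change keeps any module that receives optimizer updates in full `torch.float32`: stepping fp16 master weights directly loses small updates to rounding, while `torch.autocast` already handles reduced precision per-op during the forward pass. A runnable sketch of the casting rule with a toy module:

```python
import torch

weight_dtype = torch.float16       # stand-in for the trainer's weight_dtype
train_text_encoder = True          # stand-in for args.train_text_encoder

text_encoder = torch.nn.Linear(8, 8)
# Trained modules keep fp32 master weights; frozen ones may be cast down.
dtype = torch.float32 if train_text_encoder else weight_dtype
text_encoder = text_encoder.to(dtype=dtype)
assert next(text_encoder.parameters()).dtype is torch.float32
```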
@@ -754,6 +762,14 @@ def main():
         gradient_as_bucket_view=True
     )
+    if args.train_text_encoder:
+        text_encoder = torch.nn.parallel.DistributedDataParallel(
+            text_encoder,
+            device_ids=[rank],
+            output_device=rank,
+            gradient_as_bucket_view=True
+        )
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
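The text encoder is wrapped exactly like the UNet; `gradient_as_bucket_view=True` lets gradients alias DDP's communication buckets instead of holding a second copy. A sketch of the wrapping helper this implies (assumes `torch.distributed.init_process_group` has already run, as the trainer does elsewhere):

```python
import torch

def wrap_ddp(model: torch.nn.Module, rank: int) -> torch.nn.Module:
    # Valid only after torch.distributed.init_process_group(); one rank per GPU.
    return torch.nn.parallel.DistributedDataParallel(
        model.to(rank),
        device_ids=[rank],
        output_device=rank,
        gradient_as_bucket_view=True,  # grads share storage with DDP buckets
    )
```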
@@ -774,10 +790,12 @@ def main():
     )
     """
+    optimizer_parameters = unet.parameters() if not args.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters())
     # Create distributed optimizer
     from torch.distributed.optim import ZeroRedundancyOptimizer
     optimizer = ZeroRedundancyOptimizer(
-        unet.parameters(),
+        optimizer_parameters,
         optimizer_class=optimizer_cls,
         parameters_as_bucket_view=True,
         lr=args.lr,
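When both models train, their parameters are merged into the single `ZeroRedundancyOptimizer` via `itertools.chain`, which lazily concatenates the two iterators; note this assumes `itertools` is imported at the top of the file. The merging idea, runnable with toy stand-ins and a plain optimizer:

```python
import itertools
import torch

unet_stub = torch.nn.Linear(4, 4)          # stand-ins for unet / text_encoder
text_encoder_stub = torch.nn.Linear(4, 4)

params = itertools.chain(unet_stub.parameters(), text_encoder_stub.parameters())
optimizer = torch.optim.AdamW(params, lr=5e-6)
# Both modules' weights and biases land in a single parameter group.
assert len(optimizer.param_groups[0]["params"]) == 4
```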
@@ -847,7 +865,7 @@ def main():
                 ema_unet.store(unet.parameters())
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
-                text_encoder=text_encoder,
+                text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
                 vae=vae,
                 unet=unet.module,
                 tokenizer=tokenizer,
@@ -866,6 +884,8 @@ def main():
     loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
     for epoch in range(args.epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for _, batch in enumerate(train_dataloader):
             if args.resume and global_step < target_global_step:
                 if rank == 0:
@@ -898,6 +918,7 @@ def main():
                 else:
                     raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
+            if not args.train_text_encoder:
                 with unet.join():
                     # Predict the noise residual and compute loss
                     with torch.autocast('cuda', enabled=args.fp16):
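`join()` here is DDP's uneven-inputs context manager: ranks that run out of batches early keep shadowing the collective ops of ranks still training, preventing hangs. When the text encoder also trains (the `else` branch in the next hunk), both wrapped models must enter their join contexts, which a multi-item `with` handles in one line. A schematic of the pattern (assumes `unet` and `text_encoder` are already DDP-wrapped; `training_step` is a hypothetical stand-in for the loss computation):

```python
# Schematic of the dual-join pattern used when both models train.
with unet.join(), text_encoder.join():
    for batch in train_dataloader:       # per-rank batch counts may differ
        loss = training_step(batch)      # hypothetical step function
        loss.backward()                  # DDP hooks fire inside both joins
```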
@@ -912,6 +933,22 @@ def main():
                         scaler.update()
                         lr_scheduler.step()
                         optimizer.zero_grad()
+            else:
+                with unet.join(), text_encoder.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                        loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
             # Update EMA
             if args.use_ema:
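One caveat worth flagging in this hunk: `clip_grad_norm_` runs on gradients that are still multiplied by the `GradScaler`'s loss scale, so the effective threshold of `1.0` drifts with the current scale. The pattern PyTorch's AMP docs describe unscales first; a sketch of that variant, using the trainer's names:

```python
# AMP-documented ordering: unscale, then clip against true magnitudes.
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # gradients are now in real units
torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
scaler.step(optimizer)      # skips the step if inf/nan gradients were found
scaler.update()
optimizer.zero_grad()
```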
@@ -960,7 +997,7 @@ def main():
         scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
         pipeline = StableDiffusionPipeline(
-            text_encoder=text_encoder,
+            text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
             vae=vae,
             unet=unet.module,
             tokenizer=tokenizer,