Merge pull request #58 from harubaru/text-encoder-updates
FEAT: Text Encoder (CLIP) Training
Commit 27d301c5b9
@@ -92,6 +92,8 @@ parser.add_argument('--extended_validation', type=bool_t, default='False', help=
 parser.add_argument('--no_migration', type=bool_t, default='False', help='Do not perform migration of dataset while the `--resize` flag is active. Migration creates an adjacent folder to the dataset with <dataset_dirname>_cropped.')
 parser.add_argument('--skip_validation', type=bool_t, default='False', help='Skip validation of images, useful for speeding up loading of very large datasets that have already been validated.')
 parser.add_argument('--extended_mode_chunks', type=int, default=0, help='Enables extended mode for tokenization with given amount of maximum chunks. Values < 2 disable.')
+parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")
+
 
 args = parser.parse_args()
 
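Note: unlike the surrounding `bool_t` options, the new flag uses argparse's `store_true` action, so it is simply absent or present on the command line. A minimal, standalone sketch of that behaviour (the throwaway parser below is illustrative, not the trainer's):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--train_text_encoder", action="store_true", help="Whether to train the text encoder")

# store_true: False unless the flag is passed, and no value is expected after it.
print(parser.parse_args([]).train_text_encoder)                        # False
print(parser.parse_args(["--train_text_encoder"]).train_text_encoder)  # True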
@@ -498,6 +500,9 @@ class AspectDataset(torch.utils.data.Dataset):
         self.device = device
         self.ucg = ucg
 
+        if type(self.text_encoder) is torch.nn.parallel.DistributedDataParallel:
+            self.text_encoder = self.text_encoder.module
+
         self.transforms = torchvision.transforms.Compose([
             torchvision.transforms.RandomHorizontalFlip(p=0.5),
             torchvision.transforms.ToTensor(),
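The dataset now unwraps a DistributedDataParallel-wrapped text encoder, presumably so later code can work with the underlying CLIP model rather than the DDP wrapper. A small illustrative helper showing the same pattern (the name unwrap_ddp is hypothetical, not part of the patch):

import torch

def unwrap_ddp(model: torch.nn.Module) -> torch.nn.Module:
    # DDP exposes the wrapped network as .module; anything else passes through unchanged.
    if type(model) is torch.nn.parallel.DistributedDataParallel:
        return model.module
    return model

encoder = torch.nn.Linear(4, 4)   # stand-in for the CLIP text encoder
assert unwrap_ddp(encoder) is encoder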
@@ -731,11 +736,14 @@ def main():
 
     # Freeze vae and text_encoder
     vae.requires_grad_(False)
-    text_encoder.requires_grad_(False)
+    if not args.train_text_encoder:
+        text_encoder.requires_grad_(False)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        if args.train_text_encoder:
+            text_encoder.gradient_checkpointing_enable()
 
     if args.use_xformers:
         unet.set_use_memory_efficient_attention_xformers(True)
 
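With the flag off, the text encoder stays frozen exactly as before; with it on, its parameters keep requires_grad=True and gradient checkpointing is enabled alongside the UNet's (the method names differ because diffusers exposes enable_gradient_checkpointing while transformers exposes gradient_checkpointing_enable). A minimal sketch of what freezing does, using a stand-in module rather than the real CLIP encoder:

import torch

text_encoder_like = torch.nn.Linear(4, 4)    # stand-in for the CLIP text encoder
train_text_encoder = False                   # mirrors args.train_text_encoder

if not train_text_encoder:
    text_encoder_like.requires_grad_(False)  # no grads flow, so the optimizer never updates it

print(any(p.requires_grad for p in text_encoder_like.parameters()))  # False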
@@ -745,7 +753,7 @@ def main():
     # move models to device
     vae = vae.to(device, dtype=weight_dtype)
     unet = unet.to(device, dtype=torch.float32)
-    text_encoder = text_encoder.to(device, dtype=weight_dtype)
+    text_encoder = text_encoder.to(device, dtype=weight_dtype if not args.train_text_encoder else torch.float32)
 
     unet = torch.nn.parallel.DistributedDataParallel(
         unet,
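Keeping the text encoder in float32 when it is trained mirrors how the UNet is already handled: trainable weights stay in full precision while autocast and the GradScaler handle half-precision compute, whereas a frozen encoder can live in weight_dtype. A sketch of just the dtype selection, with stand-in flags:

import torch

fp16 = True                  # mirrors args.fp16
train_text_encoder = True    # mirrors args.train_text_encoder

weight_dtype = torch.float16 if fp16 else torch.float32
text_encoder_dtype = weight_dtype if not train_text_encoder else torch.float32
print(text_encoder_dtype)    # torch.float32: trained modules keep fp32 master weights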
@@ -753,7 +761,15 @@ def main():
         output_device=rank,
         gradient_as_bucket_view=True
     )
 
+    if args.train_text_encoder:
+        text_encoder = torch.nn.parallel.DistributedDataParallel(
+            text_encoder,
+            device_ids=[rank],
+            output_device=rank,
+            gradient_as_bucket_view=True
+        )
+
     if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
         try:
             import bitsandbytes as bnb
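When trained, the text encoder gets the same DDP wrapping as the UNet: per-rank device placement and gradient_as_bucket_view=True, which lets gradients alias the communication buckets and saves memory. A single-process, CPU/gloo sketch of that wrapping (device_ids/output_device are omitted because they are CUDA-only; this is not the trainer's multi-GPU launch):

import os
import torch
import torch.distributed as dist

# One-process "group" just so DistributedDataParallel can be constructed.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

text_encoder_like = torch.nn.Linear(8, 8)    # stand-in for the CLIP text encoder
ddp_text_encoder = torch.nn.parallel.DistributedDataParallel(
    text_encoder_like,
    gradient_as_bucket_view=True,            # gradients view the allreduce buckets instead of copies
)
print(type(ddp_text_encoder.module))         # the original module, as the dataset/pipeline code expects

dist.destroy_process_group()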
@@ -774,10 +790,12 @@ def main():
     )
     """
 
+    optimizer_parameters = unet.parameters() if not args.train_text_encoder else itertools.chain(unet.parameters(), text_encoder.parameters())
+
     # Create distributed optimizer
     from torch.distributed.optim import ZeroRedundancyOptimizer
     optimizer = ZeroRedundancyOptimizer(
-        unet.parameters(),
+        optimizer_parameters,
         optimizer_class=optimizer_cls,
         parameters_as_bucket_view=True,
         lr=args.lr,
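ZeroRedundancyOptimizer now receives either the UNet's parameters or a chained iterator over both models' parameters, so a single sharded optimizer updates both; this assumes itertools is imported elsewhere in the script (not shown in this diff). A standalone sketch of the chaining with plain AdamW, since ZeroRedundancyOptimizer itself needs an initialized process group:

import itertools
import torch

unet_like = torch.nn.Linear(4, 4)            # stand-in for the UNet
text_encoder_like = torch.nn.Linear(4, 4)    # stand-in for the CLIP text encoder
train_text_encoder = True                    # mirrors args.train_text_encoder

optimizer_parameters = unet_like.parameters() if not train_text_encoder else itertools.chain(
    unet_like.parameters(), text_encoder_like.parameters())

optimizer = torch.optim.AdamW(optimizer_parameters, lr=5e-6)
print(sum(len(group["params"]) for group in optimizer.param_groups))  # 4: both weights and both biases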
@@ -847,7 +865,7 @@ def main():
                 ema_unet.store(unet.parameters())
                 ema_unet.copy_to(unet.parameters())
             pipeline = StableDiffusionPipeline(
-                text_encoder=text_encoder,
+                text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
                 vae=vae,
                 unet=unet.module,
                 tokenizer=tokenizer,
@@ -866,6 +884,8 @@ def main():
     loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
     for epoch in range(args.epochs):
         unet.train()
+        if args.train_text_encoder:
+            text_encoder.train()
         for _, batch in enumerate(train_dataloader):
             if args.resume and global_step < target_global_step:
                 if rank == 0:
@@ -898,20 +918,37 @@ def main():
             else:
                 raise ValueError(f"Unknown prediction type: {noise_scheduler.config.prediction_type}")
 
-            with unet.join():
-                # Predict the noise residual and compute loss
-                with torch.autocast('cuda', enabled=args.fp16):
-                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-                loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
-
-                # backprop and update
-                scaler.scale(loss).backward()
-                torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
-                scaler.step(optimizer)
-                scaler.update()
-                lr_scheduler.step()
-                optimizer.zero_grad()
+            if not args.train_text_encoder:
+                with unet.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
+            else:
+                with unet.join(), text_encoder.join():
+                    # Predict the noise residual and compute loss
+                    with torch.autocast('cuda', enabled=args.fp16):
+                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    loss = torch.nn.functional.mse_loss(noise_pred.float(), target.float(), reduction="mean")
+
+                    # backprop and update
+                    scaler.scale(loss).backward()
+                    torch.nn.utils.clip_grad_norm_(unet.parameters(), 1.0)
+                    torch.nn.utils.clip_grad_norm_(text_encoder.parameters(), 1.0)
+                    scaler.step(optimizer)
+                    scaler.update()
+                    lr_scheduler.step()
+                    optimizer.zero_grad()
 
             # Update EMA
             if args.use_ema:
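When the text encoder is trained, the forward/backward pass runs under both models' DDP join() contexts (so ranks that run out of batches early do not deadlock the collective ops), gradients are clipped per model, and a single optimizer step updates both. A condensed, CPU-runnable sketch of that update path with stand-in modules (autocast and the GradScaler are disabled here; the trainer enables them on CUDA when --fp16 is set):

import itertools
import torch

unet_like = torch.nn.Linear(8, 8)            # stand-in for the UNet
text_encoder_like = torch.nn.Linear(8, 8)    # stand-in for the CLIP text encoder
optimizer = torch.optim.AdamW(
    itertools.chain(unet_like.parameters(), text_encoder_like.parameters()), lr=5e-6)
scaler = torch.cuda.amp.GradScaler(enabled=False)    # enabled=args.fp16 in the trainer

batch = torch.randn(2, 8)
with torch.autocast("cpu", enabled=False):           # 'cuda', enabled=args.fp16 in the trainer
    pred = unet_like(text_encoder_like(batch))
    loss = torch.nn.functional.mse_loss(pred.float(), torch.zeros_like(pred).float(), reduction="mean")

# backprop and update: one scaled backward, clip both parameter sets, one shared optimizer step
scaler.scale(loss).backward()
torch.nn.utils.clip_grad_norm_(unet_like.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(text_encoder_like.parameters(), 1.0)
scaler.step(optimizer)
scaler.update()
optimizer.zero_grad()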
@@ -960,7 +997,7 @@ def main():
                 scheduler=PNDMScheduler.from_pretrained(args.model, subfolder="scheduler", use_auth_token=args.hf_token)
 
             pipeline = StableDiffusionPipeline(
-                text_encoder=text_encoder,
+                text_encoder=text_encoder if type(text_encoder) is not torch.nn.parallel.DistributedDataParallel else text_encoder.module,
                 vae=vae,
                 unet=unet.module,
                 tokenizer=tokenizer,