Add checkpoint resuming, exception handling, and learning rate schedulers

Anthony Mercurio 2022-11-06 01:05:40 -07:00 committed by GitHub
parent 2321e22fc1
commit 56923359a3
1 changed file with 128 additions and 114 deletions


@@ -25,6 +25,7 @@ import itertools
import numpy as np
import json
import re
import traceback
try:
pynvml.nvmlInit()
@@ -47,6 +48,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
# TODO: add custom VAE support. should be simple with diffusers
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
@@ -63,6 +65,8 @@ parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Learning rate scheduler warmup, as a fraction of the total number of training steps (e.g. 0.1 means 10 percent of the total steps).')
parser.add_argument('--seed', type=int, default=42, help='Seed for the random number generator, used for reproducibility.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
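Note: the new --lr_scheduler_warmup value is a fraction of total training steps, not an absolute step count. Below is a minimal sketch (not part of the diff) of how that fraction maps onto the arguments diffusers' get_scheduler expects; the epoch and step counts are illustrative assumptions.

# Sketch only -- illustrative values, not taken from the run.
import torch
from diffusers.optimization import get_scheduler

epochs = 10                  # assumed --epochs
num_steps_per_epoch = 1000   # depends on dataset size and batch size
warmup_fraction = 0.05       # --lr_scheduler_warmup

total_steps = epochs * num_steps_per_epoch
warmup_steps = int(warmup_fraction * total_steps)   # 500 warmup steps here

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=5e-6)
lr_scheduler = get_scheduler(
    'cosine',                # --lr_scheduler: 'cosine', 'linear', or 'constant'
    optimizer=optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps,
)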
@@ -93,17 +97,6 @@ def get_world_size() -> int:
return 1
return torch.distributed.get_world_size()
# Inform the user of the host and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)
print("TORCH:", torch.__version__)
print("TRANSFORMERS:", transformers.__version__)
print("DIFFUSERS:", diffusers.__version__)
print("MODEL:", args.model)
print("FP16:", args.fp16)
print("RESOLUTION:", args.resolution)
def get_gpu_ram() -> str:
"""
Returns memory usage statistics for the CPU, GPU, and Torch.
@@ -483,15 +476,24 @@ def main():
world_size = get_world_size()
torch.cuda.set_device(rank)
if args.hf_token is None:
args.hf_token = os.environ['HF_API_TOKEN']
if rank == 0:
os.makedirs(args.output_path, exist_ok=True)
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
# remove hf_token from args so sneaky people don't steal it from the wandb logs
sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
# Inform the user of the host and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)
print("TORCH:", torch.__version__)
print("TRANSFORMERS:", transformers.__version__)
print("DIFFUSERS:", diffusers.__version__)
print("MODEL:", args.model)
print("FP16:", args.fp16)
print("RESOLUTION:", args.resolution)
if args.hf_token is None:
args.hf_token = os.environ['HF_API_TOKEN']
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
device = torch.device('cuda')
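Note: sanitized_args above keeps the HF token out of the wandb run config. A small sketch (not part of the diff) generalizing the same idea; the extra key names are hypothetical.

# Sketch only -- extra key names are hypothetical.
SENSITIVE_KEYS = {'hf_token', 'wandb_api_key', 'aws_secret_access_key'}

def sanitize_config(args) -> dict:
    """Return vars(args) with secret-bearing keys removed before logging."""
    return {k: v for k, v in vars(args).items() if k not in SENSITIVE_KEYS}

# usage, mirroring the wandb.init call in the diff:
# run = wandb.init(project=args.project_id, name=args.run_name,
#                  config=sanitize_config(args), dir=args.output_path + '/wandb')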
@@ -504,6 +506,9 @@ def main():
torch.manual_seed(args.seed)
print('RANDOM SEED:', args.seed)
if args.resume:
args.model = args.resume
tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
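Note: because checkpoints are written with pipeline.save_pretrained (see the save_checkpoint change below), a checkpoint directory has the same subfolder layout as the base model, which is why --resume can simply replace args.model. A sketch (not part of the diff) of loading from such a directory; the path is illustrative.

# Sketch only -- the checkpoint path is illustrative.
from transformers import CLIPTokenizer, CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel

resume_path = './output/my_run_1500'  # hypothetical f'{output_path}/{run_name}_{global_step}'

tokenizer = CLIPTokenizer.from_pretrained(resume_path, subfolder='tokenizer')
text_encoder = CLIPTextModel.from_pretrained(resume_path, subfolder='text_encoder')
vae = AutoencoderKL.from_pretrained(resume_path, subfolder='vae')
unet = UNet2DConditionModel.from_pretrained(resume_path, subfolder='unet')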
@@ -561,11 +566,6 @@ def main():
collate_fn=dataset.collate_fn
)
lr_scheduler = get_scheduler(
'constant',
optimizer=optimizer
)
weight_dtype = torch.float16 if args.fp16 else torch.float32
# move models to device
@@ -585,7 +585,18 @@ def main():
progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
global_step = 0
def save_checkpoint():
if args.resume:
global_step = int(args.resume.split('_')[-1])
lr_scheduler = get_scheduler(
args.lr_scheduler,
optimizer=optimizer,
num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
num_training_steps=args.epochs * num_steps_per_epoch,
#last_epoch=(global_step // num_steps_per_epoch) - 1,
)
def save_checkpoint(global_step):
if rank == 0:
if args.use_ema:
ema_unet.copy_to(unet.parameters())
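Note: on resume, the step counter is recovered from the checkpoint directory name itself rather than from a separate state file. A minimal sketch (not part of the diff) of that naming convention; the path is illustrative.

# Sketch only -- the path is illustrative.
def step_from_checkpoint(path: str) -> int:
    """Recover global_step from a '{run_name}_{global_step}' directory name."""
    return int(path.rstrip('/').split('_')[-1])

assert step_from_checkpoint('./output/my_run_1500') == 1500
# Caveat: this assumes the last '_'-separated token of the directory name is
# the step count, so it breaks if a run name ends in a non-numeric '_' suffix.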
@@ -600,16 +611,16 @@ def main():
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
pipeline.save_pretrained(args.output_path)
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
# barrier
torch.distributed.barrier()
# train!
try:
loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
for epoch in range(args.epochs):
unet.train()
train_loss = 0.0
for step, batch in enumerate(train_dataloader):
for _, batch in enumerate(train_dataloader):
b_start = time.perf_counter()
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
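Note: the 0.18215 factor above is Stable Diffusion's standard VAE latent scaling constant; decoding applies the inverse. A small sketch (not part of the diff) of the round trip; the model ID and tensor shape are illustrative.

# Sketch only -- model ID and tensor shape are illustrative.
import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained('CompVis/stable-diffusion-v1-4', subfolder='vae')
pixels = torch.randn(1, 3, 512, 512)  # stand-in for a batch of images in [-1, 1]

with torch.no_grad():
    latents = vae.encode(pixels).latent_dist.sample() * 0.18215  # encode + scale
    decoded = vae.decode(latents / 0.18215).sample               # unscale + decode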
@@ -667,6 +678,7 @@ def main():
"train/loss": loss.detach().item() / world_size,
"train/lr": lr_scheduler.get_last_lr()[0],
"train/epoch": epoch,
"train/step": global_step,
"train/samples_seen": samples_seen,
"perf/rank_samples_per_second": rank_images_per_second,
"perf/global_samples_per_second": world_images_per_second,
@@ -675,7 +687,7 @@ def main():
run.log(logs)
if global_step % args.save_steps == 0:
save_checkpoint()
save_checkpoint(global_step)
if global_step % args.image_log_steps == 0:
if rank == 0:
@@ -705,9 +717,11 @@ def main():
del pipeline
gc.collect()
torch.distributed.barrier()
except Exception as e:
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
pass
if rank == 0:
save_checkpoint()
save_checkpoint(global_step)
torch.distributed.barrier()
cleanup()
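Note: the try/except added above means a crash mid-training still produces a final checkpoint before cleanup. A condensed, self-contained sketch (not part of the diff) of that pattern with the distributed rank/barrier handling stripped out; the helper functions are trivial stand-ins.

# Sketch only -- helpers are trivial stand-ins for the real training step and save.
import traceback

def train_one_step(step: int) -> None:
    if step == 3:
        raise RuntimeError('simulated failure')  # e.g. a CUDA OOM mid-run

def save_checkpoint(step: int) -> None:
    print(f'would save a checkpoint at step {step}')

global_step = 0
try:
    for step in range(10):
        train_one_step(step)
        global_step += 1
except Exception as e:
    # log the failure, then fall through so progress up to the crash is saved
    print(f'Exception caught at step {global_step}:\n{e}\n{traceback.format_exc()}')
save_checkpoint(global_step)  # runs on success or failure, as in the diff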