Add checkpoint resuming, exception handling, and learning rate schedulers
parent 2321e22fc1
commit 56923359a3
@@ -25,6 +25,7 @@ import itertools
 import numpy as np
 import json
 import re
+import traceback
 
 try:
     pynvml.nvmlInit()
@@ -47,6 +48,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 # TODO: add custom VAE support. should be simple with diffusers
 parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
 parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
+parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
 parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
 parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
 parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
@@ -63,6 +65,8 @@ parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
 parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
 parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
 parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
+parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
+parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Learning rate scheduler warmup steps. This is a percentage of the total number of steps in the training run. 0.1 means 10 percent of the total number of steps.')
 parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproducibility purposes.')
 parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
 parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
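The new --lr_scheduler_warmup flag is a fraction of the run, not a raw step count; the scheduler setup further down converts it with int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs). A small arithmetic sketch with illustrative numbers (the real counts come from the dataloader length and --epochs at runtime):

    # Hypothetical step counts purely for illustration.
    num_steps_per_epoch = 2000
    epochs = 5
    lr_scheduler_warmup = 0.05  # the flag's default: 5 percent of the whole run

    total_steps = num_steps_per_epoch * epochs                  # 10000
    num_warmup_steps = int(lr_scheduler_warmup * total_steps)   # 500 warmup steps
    print(num_warmup_steps)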
@@ -93,17 +97,6 @@ def get_world_size() -> int:
         return 1
     return torch.distributed.get_world_size()
 
-# Inform the user of host, and various versions -- useful for debugging issues.
-print("RUN_NAME:", args.run_name)
-print("HOST:", socket.gethostname())
-print("CUDA:", torch.version.cuda)
-print("TORCH:", torch.__version__)
-print("TRANSFORMERS:", transformers.__version__)
-print("DIFFUSERS:", diffusers.__version__)
-print("MODEL:", args.model)
-print("FP16:", args.fp16)
-print("RESOLUTION:", args.resolution)
-
 def get_gpu_ram() -> str:
     """
     Returns memory usage statistics for the CPU, GPU, and Torch.
@@ -483,15 +476,24 @@ def main():
     world_size = get_world_size()
     torch.cuda.set_device(rank)
 
-    if args.hf_token is None:
-        args.hf_token = os.environ['HF_API_TOKEN']
-
     if rank == 0:
         os.makedirs(args.output_path, exist_ok=True)
-        # remove hf_token from args so sneaky people don't steal it from the wandb logs
-        sanitized_args = {k: v for k, v in vars(args).items() if k not in ['hf_token']}
-        run = wandb.init(project=args.project_id, name=args.run_name, config=sanitized_args, dir=args.output_path+'/wandb')
+        run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
+
+        # Inform the user of host, and various versions -- useful for debugging issues.
+        print("RUN_NAME:", args.run_name)
+        print("HOST:", socket.gethostname())
+        print("CUDA:", torch.version.cuda)
+        print("TORCH:", torch.__version__)
+        print("TRANSFORMERS:", transformers.__version__)
+        print("DIFFUSERS:", diffusers.__version__)
+        print("MODEL:", args.model)
+        print("FP16:", args.fp16)
+        print("RESOLUTION:", args.resolution)
+
+    if args.hf_token is None:
+        args.hf_token = os.environ['HF_API_TOKEN']
+        print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
 
     device = torch.device('cuda')
 
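With config=vars(args) now going to WandB unchanged, the warning above is the only guard against leaking the token; a minimal sketch of the environment-variable route it recommends (the variable name comes from the diff, the error message is illustrative):

    import os

    # Prefer the environment variable over --hf_token so the secret never ends up
    # in the argparse namespace that WandB logs.
    hf_token = os.environ.get('HF_API_TOKEN')
    if hf_token is None:
        raise SystemExit('Set HF_API_TOKEN in the environment before launching the finetuner.')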
@@ -504,6 +506,9 @@ def main():
     torch.manual_seed(args.seed)
     print('RANDOM SEED:', args.seed)
 
+    if args.resume:
+        args.model = args.resume
+
     tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
     text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
     vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
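Pointing args.model at the checkpoint directory is enough because the from_pretrained loaders accept a local directory written by save_pretrained just as readily as a Hub ID. A minimal sketch, using a hypothetical checkpoint directory in this commit's {run_name}_{global_step} naming:

    from diffusers import AutoencoderKL
    from transformers import CLIPTokenizer

    # Hypothetical directory previously written by pipeline.save_pretrained().
    checkpoint_dir = './output/myrun_1500'

    # The same subfolder layout the finetuner expects on the Hub also exists on disk,
    # so the loaders work unchanged.
    tokenizer = CLIPTokenizer.from_pretrained(checkpoint_dir, subfolder='tokenizer')
    vae = AutoencoderKL.from_pretrained(checkpoint_dir, subfolder='vae')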
@@ -561,11 +566,6 @@ def main():
         collate_fn=dataset.collate_fn
     )
 
-    lr_scheduler = get_scheduler(
-        'constant',
-        optimizer=optimizer
-    )
-
     weight_dtype = torch.float16 if args.fp16 else torch.float32
 
     # move models to device
@@ -585,7 +585,18 @@ def main():
     progress_bar = tqdm.tqdm(range(args.epochs * num_steps_per_epoch), desc="Total Steps", leave=False)
     global_step = 0
 
-    def save_checkpoint():
+    if args.resume:
+        global_step = int(args.resume.split('_')[-1])
+
+    lr_scheduler = get_scheduler(
+        args.lr_scheduler,
+        optimizer=optimizer,
+        num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
+        num_training_steps=args.epochs * num_steps_per_epoch,
+        #last_epoch=(global_step // num_steps_per_epoch) - 1,
+    )
+
+    def save_checkpoint(global_step):
         if rank == 0:
             if args.use_ema:
                 ema_unet.copy_to(unet.parameters())
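Recovering global_step from the resume path relies on the naming convention save_checkpoint uses below, f'{args.output_path}/{args.run_name}_{global_step}': the step is always the last underscore-separated token. A small round-trip sketch with illustrative values:

    # Illustrative values; the real ones come from argparse and the training loop.
    output_path, run_name, global_step = './output', 'my_run', 1500

    checkpoint_dir = f'{output_path}/{run_name}_{global_step}'  # './output/my_run_1500'
    resumed_step = int(checkpoint_dir.split('_')[-1])           # 1500; underscores in run_name are fine
    assert resumed_step == global_step                          # as long as the step stays the final token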
@@ -600,114 +611,117 @@ def main():
                 safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
                 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
             )
-            pipeline.save_pretrained(args.output_path)
+            pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
         # barrier
         torch.distributed.barrier()
 
     # train!
-    loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
-    for epoch in range(args.epochs):
-        unet.train()
-        train_loss = 0.0
-        for step, batch in enumerate(train_dataloader):
+    try:
+        loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
+        for epoch in range(args.epochs):
+            unet.train()
+            for _, batch in enumerate(train_dataloader):
                 b_start = time.perf_counter()
                 latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
                 latents = latents * 0.18215
 
                 # Sample noise
                 noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                 timesteps = timesteps.long()
 
                 # Add noise to the latents according to the noise magnitude at each timestep
                 # (this is the forward diffusion process)
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
                 if args.clip_penultimate:
                     encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
                 else:
                     encoder_hidden_states = encoder_hidden_states.last_hidden_state
 
                 # Predict the noise residual and compute loss
                 with torch.autocast('cuda', enabled=args.fp16):
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
                 loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 # Backprop and all reduce
                 scaler.scale(loss).backward()
                 scaler.step(optimizer)
                 scaler.update()
                 lr_scheduler.step()
                 optimizer.zero_grad()
 
                 # Update EMA
                 if args.use_ema:
                     ema_unet.step(unet.parameters())
 
                 # perf
                 b_end = time.perf_counter()
                 seconds_per_step = b_end - b_start
                 steps_per_second = 1 / seconds_per_step
                 rank_images_per_second = args.batch_size * steps_per_second
                 world_images_per_second = rank_images_per_second * world_size
                 samples_seen = global_step * args.batch_size * world_size
 
                 # All reduce loss
                 torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
 
-            if rank == 0:
-                progress_bar.update(1)
-                global_step += 1
-                logs = {
-                    "train/loss": loss.detach().item() / world_size,
-                    "train/lr": lr_scheduler.get_last_lr()[0],
-                    "train/epoch": epoch,
-                    "train/samples_seen": samples_seen,
-                    "perf/rank_samples_per_second": rank_images_per_second,
-                    "perf/global_samples_per_second": world_images_per_second,
-                }
-                progress_bar.set_postfix(logs)
-                run.log(logs)
-
-            if global_step % args.save_steps == 0:
-                save_checkpoint()
-
-            if global_step % args.image_log_steps == 0:
-                if rank == 0:
-                    # get prompt from random batch
-                    prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
-                    pipeline = StableDiffusionPipeline(
-                        text_encoder=text_encoder,
-                        vae=vae,
-                        unet=unet,
-                        tokenizer=tokenizer,
-                        scheduler=PNDMScheduler(
-                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                        ),
-                        safety_checker=None, # disable safety checker to save memory
-                        feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
-                    ).to(device)
-                    # inference
-                    images = []
-                    with torch.no_grad():
-                        with torch.autocast('cuda', enabled=args.fp16):
-                            for _ in range(args.image_log_amount):
-                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
-                    # log images under single caption
-                    run.log({'images': images})
-
-                    # cleanup so we don't run out of memory
-                    del pipeline
-                    gc.collect()
-                torch.distributed.barrier()
-
-    if rank == 0:
-        save_checkpoint()
+                if rank == 0:
+                    progress_bar.update(1)
+                    global_step += 1
+                    logs = {
+                        "train/loss": loss.detach().item() / world_size,
+                        "train/lr": lr_scheduler.get_last_lr()[0],
+                        "train/epoch": epoch,
+                        "train/step": global_step,
+                        "train/samples_seen": samples_seen,
+                        "perf/rank_samples_per_second": rank_images_per_second,
+                        "perf/global_samples_per_second": world_images_per_second,
+                    }
+                    progress_bar.set_postfix(logs)
+                    run.log(logs)
+
+                if global_step % args.save_steps == 0:
+                    save_checkpoint(global_step)
+
+                if global_step % args.image_log_steps == 0:
+                    if rank == 0:
+                        # get prompt from random batch
+                        prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+                        pipeline = StableDiffusionPipeline(
+                            text_encoder=text_encoder,
+                            vae=vae,
+                            unet=unet,
+                            tokenizer=tokenizer,
+                            scheduler=PNDMScheduler(
+                                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                            ),
+                            safety_checker=None, # disable safety checker to save memory
+                            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
+                        ).to(device)
+                        # inference
+                        images = []
+                        with torch.no_grad():
+                            with torch.autocast('cuda', enabled=args.fp16):
+                                for _ in range(args.image_log_amount):
+                                    images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                        # log images under single caption
+                        run.log({'images': images})
+
+                        # cleanup so we don't run out of memory
+                        del pipeline
+                        gc.collect()
+                    torch.distributed.barrier()
+    except Exception as e:
+        print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
+        pass
+
+    save_checkpoint(global_step)
 
     torch.distributed.barrier()
     cleanup()
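Taken together, the loop is now wrapped so that any exception is reported with its traceback and a final checkpoint is written either way. A minimal sketch of the same pattern, with train_one_step and write_checkpoint as hypothetical stand-ins for the script's real training step and save_checkpoint:

    import traceback

    def train_one_step(step: int) -> None:
        # Hypothetical stand-in for one forward/backward/optimizer step.
        if step == 3:
            raise RuntimeError('simulated CUDA OOM')

    def write_checkpoint(step: int) -> None:
        # Hypothetical stand-in for the script's save_checkpoint(global_step).
        print(f'checkpoint written at step {step}')

    global_step = 0
    try:
        for global_step in range(1, 11):
            train_one_step(global_step)
    except Exception as e:
        # Report the failure with its traceback, then fall through to the final save,
        # mirroring the diff's except-then-save_checkpoint structure.
        print(f'Exception caught at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')

    write_checkpoint(global_step)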