Scheduler Options & Resume Updates

This commit is contained in:
cafeai 2022-11-06 19:08:21 +09:00
parent 56923359a3
commit b303fdc293
1 changed files with 34 additions and 9 deletions

View File

@ -33,7 +33,7 @@ except pynvml.nvml.NVMLError_LibraryNotFound:
pynvml = None pynvml = None
from typing import Iterable from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@ -77,6 +77,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision') parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.') parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps') parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding') parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits') parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
args = parser.parse_args() args = parser.parse_args()
@ -586,7 +588,8 @@ def main():
global_step = 0 global_step = 0
if args.resume: if args.resume:
global_step = int(args.resume.split('_')[-1]) target_global_step = int(args.resume.split('_')[-1])
print(f'resuming from {args.resume}...')
lr_scheduler = get_scheduler( lr_scheduler = get_scheduler(
args.lr_scheduler, args.lr_scheduler,
@ -611,6 +614,7 @@ def main():
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"), safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
) )
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}') pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
# barrier # barrier
torch.distributed.barrier() torch.distributed.barrier()
@ -621,6 +625,11 @@ def main():
for epoch in range(args.epochs): for epoch in range(args.epochs):
unet.train() unet.train()
for _, batch in enumerate(train_dataloader): for _, batch in enumerate(train_dataloader):
if args.resume and global_step < target_global_step:
if rank == 0:
progress_bar.update(1)
global_step += 1
continue
b_start = time.perf_counter() b_start = time.perf_counter()
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample() latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215 latents = latents * 0.18215
@ -684,7 +693,7 @@ def main():
"perf/global_samples_per_second": world_images_per_second, "perf/global_samples_per_second": world_images_per_second,
} }
progress_bar.set_postfix(logs) progress_bar.set_postfix(logs)
run.log(logs) run.log(logs, step=global_step)
if global_step % args.save_steps == 0: if global_step % args.save_steps == 0:
save_checkpoint(global_step) save_checkpoint(global_step)
@ -693,15 +702,26 @@ def main():
if rank == 0: if rank == 0:
# get prompt from random batch # get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist()) prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
scheduler.set_timesteps(num_inference_steps=35)
pipeline = StableDiffusionPipeline( pipeline = StableDiffusionPipeline(
text_encoder=text_encoder, text_encoder=text_encoder,
vae=vae, vae=vae,
unet=unet, unet=unet,
tokenizer=tokenizer, tokenizer=tokenizer,
scheduler=PNDMScheduler( scheduler=scheduler,
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True safety_checker=None, # disable safety checker to save memory
),
safety_checker=None, # display safety checker to save memory
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device) ).to(device)
# inference # inference
@ -709,9 +729,14 @@ def main():
with torch.no_grad(): with torch.no_grad():
with torch.autocast('cuda', enabled=args.fp16): with torch.autocast('cuda', enabled=args.fp16):
for _ in range(args.image_log_amount): for _ in range(args.image_log_amount):
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt)) images.append(
wandb.Image(pipeline(
prompt, num_inference_steps=args.image_log_inference_steps
).images[0],
caption=prompt)
)
# log images under single caption # log images under single caption
run.log({'images': images}) run.log({'images': images}, step=global_step)
# cleanup so we don't run out of memory # cleanup so we don't run out of memory
del pipeline del pipeline