Scheduler Options & Resume Updates

This commit is contained in:
cafeai 2022-11-06 19:08:21 +09:00
parent 56923359a3
commit b303fdc293
1 changed files with 34 additions and 9 deletions

View File

@ -33,7 +33,7 @@ except pynvml.nvml.NVMLError_LibraryNotFound:
pynvml = None
from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@ -77,6 +77,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
args = parser.parse_args()
@ -586,7 +588,8 @@ def main():
global_step = 0
if args.resume:
global_step = int(args.resume.split('_')[-1])
target_global_step = int(args.resume.split('_')[-1])
print(f'resuming from {args.resume}...')
lr_scheduler = get_scheduler(
args.lr_scheduler,
@ -611,6 +614,7 @@ def main():
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
# barrier
torch.distributed.barrier()
@ -621,6 +625,11 @@ def main():
for epoch in range(args.epochs):
unet.train()
for _, batch in enumerate(train_dataloader):
if args.resume and global_step < target_global_step:
if rank == 0:
progress_bar.update(1)
global_step += 1
continue
b_start = time.perf_counter()
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
@ -684,7 +693,7 @@ def main():
"perf/global_samples_per_second": world_images_per_second,
}
progress_bar.set_postfix(logs)
run.log(logs)
run.log(logs, step=global_step)
if global_step % args.save_steps == 0:
save_checkpoint(global_step)
@ -693,15 +702,26 @@ def main():
if rank == 0:
# get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
scheduler.set_timesteps(num_inference_steps=35)
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=None, # display safety checker to save memory
scheduler=scheduler,
safety_checker=None, # disable safety checker to save memory
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device)
# inference
@ -709,9 +729,14 @@ def main():
with torch.no_grad():
with torch.autocast('cuda', enabled=args.fp16):
for _ in range(args.image_log_amount):
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
images.append(
wandb.Image(pipeline(
prompt, num_inference_steps=args.image_log_inference_steps
).images[0],
caption=prompt)
)
# log images under single caption
run.log({'images': images})
run.log({'images': images}, step=global_step)
# cleanup so we don't run out of memory
del pipeline