Scheduler Options & Resume Updates
This commit is contained in:
parent
56923359a3
commit
b303fdc293
|
@ -33,7 +33,7 @@ except pynvml.nvml.NVMLError_LibraryNotFound:
|
|||
pynvml = None
|
||||
|
||||
from typing import Iterable
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
|
||||
from diffusers.optimization import get_scheduler
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
@ -77,6 +77,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
|
|||
parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
|
||||
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
|
||||
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
|
||||
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
|
||||
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Number of inference steps to use to log images.')
|
||||
parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
|
||||
parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
|
||||
args = parser.parse_args()
|
||||
|
@ -586,7 +588,8 @@ def main():
|
|||
global_step = 0
|
||||
|
||||
if args.resume:
|
||||
global_step = int(args.resume.split('_')[-1])
|
||||
target_global_step = int(args.resume.split('_')[-1])
|
||||
print(f'resuming from {args.resume}...')
|
||||
|
||||
lr_scheduler = get_scheduler(
|
||||
args.lr_scheduler,
|
||||
|
@ -611,6 +614,7 @@ def main():
|
|||
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
)
|
||||
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
|
||||
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
|
||||
# barrier
|
||||
torch.distributed.barrier()
|
||||
|
@ -621,6 +625,11 @@ def main():
|
|||
for epoch in range(args.epochs):
|
||||
unet.train()
|
||||
for _, batch in enumerate(train_dataloader):
|
||||
if args.resume and global_step < target_global_step:
|
||||
if rank == 0:
|
||||
progress_bar.update(1)
|
||||
global_step += 1
|
||||
continue
|
||||
b_start = time.perf_counter()
|
||||
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
|
||||
latents = latents * 0.18215
|
||||
|
@ -684,7 +693,7 @@ def main():
|
|||
"perf/global_samples_per_second": world_images_per_second,
|
||||
}
|
||||
progress_bar.set_postfix(logs)
|
||||
run.log(logs)
|
||||
run.log(logs, step=global_step)
|
||||
|
||||
if global_step % args.save_steps == 0:
|
||||
save_checkpoint(global_step)
|
||||
|
@ -693,15 +702,26 @@ def main():
|
|||
if rank == 0:
|
||||
# get prompt from random batch
|
||||
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
|
||||
|
||||
if args.image_log_scheduler == 'DDIMScheduler':
|
||||
print('using DDIMScheduler scheduler')
|
||||
scheduler = DDIMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
|
||||
)
|
||||
else:
|
||||
print('using PNDMScheduler scheduler')
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
)
|
||||
|
||||
scheduler.set_timesteps(num_inference_steps=35)
|
||||
pipeline = StableDiffusionPipeline(
|
||||
text_encoder=text_encoder,
|
||||
vae=vae,
|
||||
unet=unet,
|
||||
tokenizer=tokenizer,
|
||||
scheduler=PNDMScheduler(
|
||||
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
|
||||
),
|
||||
safety_checker=None, # display safety checker to save memory
|
||||
scheduler=scheduler,
|
||||
safety_checker=None, # disable safety checker to save memory
|
||||
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
|
||||
).to(device)
|
||||
# inference
|
||||
|
@ -709,9 +729,14 @@ def main():
|
|||
with torch.no_grad():
|
||||
with torch.autocast('cuda', enabled=args.fp16):
|
||||
for _ in range(args.image_log_amount):
|
||||
images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
|
||||
images.append(
|
||||
wandb.Image(pipeline(
|
||||
prompt, num_inference_steps=args.image_log_inference_steps
|
||||
).images[0],
|
||||
caption=prompt)
|
||||
)
|
||||
# log images under single caption
|
||||
run.log({'images': images})
|
||||
run.log({'images': images}, step=global_step)
|
||||
|
||||
# cleanup so we don't run out of memory
|
||||
del pipeline
|
||||
|
|
Loading…
Reference in New Issue