Scheduler Options & Resume Updates

parent 56923359a3
commit b303fdc293
@@ -33,7 +33,7 @@ except pynvml.nvml.NVMLError_LibraryNotFound:
     pynvml = None
 
 from typing import Iterable
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
 from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
 from diffusers.optimization import get_scheduler
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
@@ -77,6 +77,8 @@ parser.add_argument('--project_id', type=str, default='diffusers', help='Project
 parser.add_argument('--fp16', dest='fp16', type=bool, default=False, help='Train in mixed precision')
 parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
 parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
+parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
+parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Scheduler to use to log images.')
 parser.add_argument('--clip_penultimate', type=bool, default=False, help='Use penultimate CLIP layer for text embedding')
 parser.add_argument('--output_bucket_info', type=bool, default=False, help='Outputs bucket information and exits')
 args = parser.parse_args()
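
A quick, stand-alone sanity check of the two new flags (a hedged sketch: the parser below mirrors only these options, and the simulated argv is invented):

    import argparse

    # Stand-in parser carrying just the two options added above.
    parser = argparse.ArgumentParser()
    parser.add_argument('--image_log_inference_steps', type=int, default=50)
    parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler")

    # Simulated command line; any flag not given keeps its default.
    args = parser.parse_args(['--image_log_scheduler', 'DDIMScheduler'])
    assert args.image_log_scheduler == 'DDIMScheduler'
    assert args.image_log_inference_steps == 50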
@@ -586,7 +588,8 @@ def main():
     global_step = 0
 
     if args.resume:
-        global_step = int(args.resume.split('_')[-1])
+        target_global_step = int(args.resume.split('_')[-1])
+        print(f'resuming from {args.resume}...')
 
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
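
The resume step comes straight out of the checkpoint directory name, which this script writes as f'{args.run_name}_{global_step}' (see the save hunk below). A minimal sketch with a hypothetical path:

    # Checkpoints are saved as '<output_path>/<run_name>_<global_step>'.
    resume = 'output/my-run_4200'                    # hypothetical --resume value
    target_global_step = int(resume.split('_')[-1])  # last '_' segment is the step
    assert target_global_step == 4200

Taking the last '_'-separated segment keeps the parse correct even when the run name itself contains underscores.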
@@ -611,6 +614,7 @@ def main():
             safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
+        print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
         pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
         # barrier
         torch.distributed.barrier()
@@ -621,6 +625,11 @@ def main():
     for epoch in range(args.epochs):
         unet.train()
         for _, batch in enumerate(train_dataloader):
+            if args.resume and global_step < target_global_step:
+                if rank == 0:
+                    progress_bar.update(1)
+                global_step += 1
+                continue
             b_start = time.perf_counter()
             latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
             latents = latents * 0.18215
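
Rather than restoring dataloader state, the new block fast-forwards: every already-seen batch bumps global_step (and the rank-0 progress bar) and is otherwise skipped. A toy version of the pattern, with a range standing in for train_dataloader:

    target_global_step, global_step = 3, 0
    for batch in range(10):                  # stand-in for train_dataloader
        if global_step < target_global_step:
            global_step += 1                 # count the batch, do no work
            continue
        global_step += 1                     # a real training step would go here
    assert global_step == 10

This only lines up with the original run if batch order is deterministic across restarts (same sampler seed and shuffling), so data coverage after a resume is best-effort.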
@@ -684,7 +693,7 @@ def main():
                     "perf/global_samples_per_second": world_images_per_second,
                 }
                 progress_bar.set_postfix(logs)
-                run.log(logs)
+                run.log(logs, step=global_step)
 
             if global_step % args.save_steps == 0:
                 save_checkpoint(global_step)
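
Passing step=global_step pins the wandb x-axis to the trainer's own counter, so metric curves continue from the resume point instead of renumbering from zero. An offline illustration (mode="disabled" avoids any network):

    import wandb

    run = wandb.init(mode="disabled")        # stub run, nothing is uploaded
    for global_step in (4201, 4202):         # e.g. steps right after a resume
        run.log({"train/loss": 0.1}, step=global_step)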
@@ -693,15 +702,26 @@ def main():
                 if rank == 0:
                     # get prompt from random batch
                     prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
+
+                    if args.image_log_scheduler == 'DDIMScheduler':
+                        print('using DDIMScheduler scheduler')
+                        scheduler = DDIMScheduler(
+                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+                        )
+                    else:
+                        print('using PNDMScheduler scheduler')
+                        scheduler = PNDMScheduler(
+                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
+                        )
+
+                    scheduler.set_timesteps(num_inference_steps=35)
                     pipeline = StableDiffusionPipeline(
                         text_encoder=text_encoder,
                         vae=vae,
                         unet=unet,
                         tokenizer=tokenizer,
-                        scheduler=PNDMScheduler(
-                            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
-                        ),
-                        safety_checker=None, # display safety checker to save memory
+                        scheduler=scheduler,
+                        safety_checker=None, # disable safety checker to save memory
                         feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
                     ).to(device)
                     # inference
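
Anything other than the exact string 'DDIMScheduler' falls through to PNDM. The same switch, lifted out as a stand-alone helper (constructor arguments copied from the hunk; the function name is mine):

    from diffusers import DDIMScheduler, PNDMScheduler

    def make_image_log_scheduler(name: str):
        # Beta settings match the Stable Diffusion v1 values used above.
        if name == 'DDIMScheduler':
            return DDIMScheduler(beta_start=0.00085, beta_end=0.012,
                                 beta_schedule="scaled_linear")
        return PNDMScheduler(beta_start=0.00085, beta_end=0.012,
                             beta_schedule="scaled_linear", skip_prk_steps=True)

The explicit scheduler.set_timesteps(num_inference_steps=35) also looks redundant: StableDiffusionPipeline calls set_timesteps itself from the num_inference_steps argument passed at call time (see the next hunk).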
@@ -709,9 +729,14 @@ def main():
                     with torch.no_grad():
                         with torch.autocast('cuda', enabled=args.fp16):
                             for _ in range(args.image_log_amount):
-                                images.append(wandb.Image(pipeline(prompt).images[0], caption=prompt))
+                                images.append(
+                                    wandb.Image(pipeline(
+                                        prompt, num_inference_steps=args.image_log_inference_steps
+                                    ).images[0],
+                                    caption=prompt)
+                                )
                     # log images under single caption
-                    run.log({'images': images})
+                    run.log({'images': images}, step=global_step)
 
                     # cleanup so we don't run out of memory
                     del pipeline
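
Wrapping each sample in wandb.Image with the shared prompt as its caption, then logging the whole list once, is what groups the images under a single entry in the UI. A self-contained, offline sketch of that logging shape (the blank PIL image stands in for pipeline output):

    import wandb
    from PIL import Image

    run = wandb.init(mode="disabled")        # stub run for illustration
    prompt = "a hypothetical test prompt"
    sample = Image.new("RGB", (64, 64))      # stand-in for pipeline(...).images[0]
    images = [wandb.Image(sample, caption=prompt) for _ in range(4)]
    run.log({'images': images}, step=0)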