option to swap training scheduler

This commit is contained in:
Victor Hall 2023-11-05 20:54:09 -05:00
parent 4fae89fdee
commit 21361a3622
1 changed files with 14 additions and 1 deletions

View File

@ -41,7 +41,7 @@ import json
from tqdm.auto import tqdm
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
DPMSolverMultistepScheduler
DPMSolverMultistepScheduler, PNDMScheduler
#from diffusers.models import AttentionBlock
from diffusers.optimization import get_scheduler
from diffusers.utils.import_utils import is_xformers_available
@ -73,6 +73,18 @@ from utils.sample_generator import SampleGenerator
_SIGTERM_EXIT_CODE = 130
_VERY_LARGE_NUMBER = 1e9
def get_training_noise_scheduler(train_sampler: str, model_root_folder, trained_betas=None):
noise_scheduler = None
if train_sampler.lower() == "pndm":
logging.info(f" * Using PNDM noise scheduler for training: {train_sampler}")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
elif train_sampler.lower() == "ddpm":
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
logging.info(f" * Using default (DDPM) noise scheduler for training: {train_sampler}")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
return noise_scheduler
def get_hf_ckpt_cache_path(ckpt_path):
return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
@ -654,6 +666,7 @@ def main(args):
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = get_training_noise_scheduler(args.train_sampler, model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")