option to swap training scheduler
This commit is contained in:
parent
4fae89fdee
commit
21361a3622
15
train.py
15
train.py
|
@ -41,7 +41,7 @@ import json
|
|||
from tqdm.auto import tqdm
|
||||
|
||||
from diffusers import StableDiffusionPipeline, AutoencoderKL, UNet2DConditionModel, DDIMScheduler, DDPMScheduler, \
|
||||
DPMSolverMultistepScheduler
|
||||
DPMSolverMultistepScheduler, PNDMScheduler
|
||||
#from diffusers.models import AttentionBlock
|
||||
from diffusers.optimization import get_scheduler
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
|
@ -73,6 +73,18 @@ from utils.sample_generator import SampleGenerator
|
|||
_SIGTERM_EXIT_CODE = 130
|
||||
_VERY_LARGE_NUMBER = 1e9
|
||||
|
||||
def get_training_noise_scheduler(train_sampler: str, model_root_folder, trained_betas=None):
|
||||
noise_scheduler = None
|
||||
if train_sampler.lower() == "pndm":
|
||||
logging.info(f" * Using PNDM noise scheduler for training: {train_sampler}")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
elif train_sampler.lower() == "ddpm":
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
else:
|
||||
logging.info(f" * Using default (DDPM) noise scheduler for training: {train_sampler}")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
return noise_scheduler
|
||||
|
||||
def get_hf_ckpt_cache_path(ckpt_path):
|
||||
return os.path.join("ckpt_cache", os.path.basename(ckpt_path))
|
||||
|
||||
|
@ -654,6 +666,7 @@ def main(args):
|
|||
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
noise_scheduler = get_training_noise_scheduler(args.train_sampler, model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
|
||||
else:
|
||||
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
|
||||
|
|
Loading…
Reference in New Issue