allow scheduler change for training

Victor Hall 2023-11-05 21:14:54 -05:00
parent 21361a3622
commit 6ea721887c
2 changed files with 8 additions and 6 deletions

View File

@@ -65,7 +65,7 @@ class EveryDreamOptimizer():
         self.unet_params = unet.parameters()
         with torch.no_grad():
-            log_action = lambda n, label: logging.info(f"{Fore.LIGHTBLUE_EX} {label} weight normal: {n}{Style.RESET_ALL}")
+            log_action = lambda n, label: logging.info(f"{Fore.LIGHTBLUE_EX} {label} weight normal: {n:.1f}{Style.RESET_ALL}")
             self._log_weight_normal(text_encoder.text_model.encoder.layers.parameters(), "text encoder", log_action)
             self._log_weight_normal(unet.parameters(), "unet", log_action)
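The only change in this file is the format spec: {n} logs the full float repr, while {n:.1f} rounds the reported norm to one decimal place. A standalone illustration (the value and the "unet" label are made-up; Fore and Style come from colorama, as in the source):

import logging

from colorama import Fore, Style

logging.basicConfig(level=logging.INFO)

n = 147.03281927  # made-up stand-in for a computed weight norm

# Before: full repr, e.g. "... weight normal: 147.03281927"
logging.info(f"{Fore.LIGHTBLUE_EX} unet weight normal: {n}{Style.RESET_ALL}")

# After: one decimal place, e.g. "... weight normal: 147.0"
logging.info(f"{Fore.LIGHTBLUE_EX} unet weight normal: {n:.1f}{Style.RESET_ALL}")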

View File

@@ -77,9 +77,10 @@ def get_training_noise_scheduler(train_sampler: str, model_root_folder, trained_
     noise_scheduler = None
     if train_sampler.lower() == "pndm":
-        noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
-    elif train_sampler.lower() == "ddpm":
-        noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
+        logging.info(f" * Using PNDM noise scheduler for training: {train_sampler}")
+        noise_scheduler = PNDMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
     elif train_sampler.lower() == "ddim":
+        logging.info(f" * Using DDIM noise scheduler for training: {train_sampler}")
         noise_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
     else:
+        logging.info(f" * Using default (DDPM) noise scheduler for training: {train_sampler}")
         noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
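With this hunk, "pndm" finally maps to PNDMScheduler (it previously fell through to DDPM), the redundant "ddpm" branch is dropped in favor of the else default, and each path logs which scheduler it picked. The same selection can be read as a lookup table; a sketch under the assumption that only the class differs between branches (the scheduler classes and from_pretrained kwargs are from Hugging Face diffusers; the checkpoint path is hypothetical):

from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler

# Case-insensitive mapping; anything unrecognized falls back to DDPM,
# matching the else branch above.
SAMPLER_CLASSES = {
    "pndm": PNDMScheduler,
    "ddim": DDIMScheduler,
    "ddpm": DDPMScheduler,
}

def pick_scheduler_class(train_sampler: str):
    return SAMPLER_CLASSES.get(train_sampler.lower(), DDPMScheduler)

# Usage: every scheduler class shares the same from_pretrained signature.
noise_scheduler = pick_scheduler_class("PNDM").from_pretrained(
    "/path/to/model",        # hypothetical model_root_folder
    subfolder="scheduler",
    trained_betas=None,      # or rescaled betas, as in the main() hunk below
)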
@@ -666,10 +667,10 @@ def main(args):
         trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
         inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
-        noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
+        noise_scheduler = get_training_noise_scheduler(args.train_sampler, model_root_folder, trained_betas=trained_betas)
     else:
         inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
-        noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
+        noise_scheduler = get_training_noise_scheduler(args.train_sampler, model_root_folder)

     tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
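Both branches of main() now route through get_training_noise_scheduler, so --train_sampler takes effect whether or not the betas were rescaled for zero terminal SNR. For context, a hedged sketch of what an enforce_zero_terminal_snr helper conventionally does, following Lin et al. (2023), "Common Diffusion Noise Schedules and Sample Steps Are Flawed"; the repo's utils version may differ in detail:

import torch

def enforce_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    # Convert betas to cumulative sqrt(alpha_bar).
    alphas = 1.0 - betas
    alphas_bar = alphas.cumprod(0)
    alphas_bar_sqrt = alphas_bar.sqrt()

    # Shift so the final value is exactly 0 (terminal SNR = 0),
    # then rescale so the first value is preserved.
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert back to betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas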
@@ -1357,6 +1358,7 @@ if __name__ == "__main__":
argparser.add_argument("--save_optimizer", action="store_true", default=False, help="saves optimizer state with ckpt, useful for resuming training later")
argparser.add_argument("--seed", type=int, default=555, help="seed used for samples and shuffling, use -1 for random")
argparser.add_argument("--shuffle_tags", action="store_true", default=False, help="randomly shuffles CSV tags in captions, for booru datasets")
argparser.add_argument("--train_sampler", type=str, default="ddpm", help="sampler used for training, (default: ddpm)", choices=["ddpm", "pndm", "ddim"])
argparser.add_argument("--keep_tags", type=int, default=0, help="Number of tags to keep when shuffle, def: 0 (shuffle all)")
argparser.add_argument("--useadam8bit", action="store_true", default=False, help="deprecated, use --optimizer_config and optimizer.json instead")
argparser.add_argument("--wandb", action="store_true", default=False, help="enable wandb logging instead of tensorboard, requires env var WANDB_API_KEY")