zero terminal fixes

This commit is contained in:
Victor Hall 2023-06-03 21:41:56 -04:00
parent 81b7b00df7
commit 1155a28867
3 changed files with 36 additions and 17 deletions

View File

@ -51,7 +51,9 @@ There are no known recommendations for the CLIP text encoder. Using an even lar
#### D-Adaption optimizers
[Dadaptation](https://arxiv.org/abs/2301.07733) [version](https://github.com/facebookresearch/dadaptation) of various optimizers. These require drastically different hyperparameters. Early indications seem to point to LR of 0.1 to 1.0 and weight decay of 0.8 may work well for these. There is a `decouple` parameter that appears to need to be set to `true` for dadaptation to work and is defaulted. Another `d0` parameter is defaulted to 1e-6 as suggested and, according to the paper authors, does not need to be tuned, but is optional. See `optimizer_dadapt.json` for an example of a fully configured `dadapt_adam` training.
[Dadaptation](https://arxiv.org/abs/2301.07733) [version](https://github.com/facebookresearch/dadaptation) of various optimizers.
These require drastically different hyperparameters. Early indications seem to point to LR of 0.1 to 1.0 and weight decay of 0.8 may work well. There is a `decouple` parameter that appears to need to be set to `true` for dadaptation to work and is defaulted. Another `d0` parameter is defaulted to 1e-6 as suggested and, according to the paper authors, does not need to be tuned, but is optional. See `optimizer_dadapt.json` for an example of a fully configured `dadapt_adam` training.
These are not memory efficient. You should use gradient checkpointing even with 24GB GPU.

View File

@ -313,11 +313,6 @@ class EveryDreamOptimizer():
if optimizer_name == "dadapt_adam":
opt_class = dadaptation.DAdaptAdam
elif optimizer_name == "dadapt_lion":
opt_class = dadaptation.DAdaptLion
elif optimizer_name == "dadapt_sgd":
opt_class = dadaptation.DAdaptSGD
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
@ -328,8 +323,29 @@ class EveryDreamOptimizer():
log_every=args.log_step,
growth_rate=1e5,
decouple=decouple,
momentum=momentum,
)
elif optimizer_name == "dadapt_lion":
opt_class = dadaptation.DAdaptLion
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
betas=(betas[0], betas[1]),
weight_decay=weight_decay,
d0=d0,
log_every=args.log_step,
)
elif optimizer_name == "dadapt_sgd":
opt_class = dadaptation.DAdaptSGD
optimizer = opt_class(
itertools.chain(parameters),
lr=curr_lr,
momentum=momentum,
weight_decay=weight_decay,
d0=d0,
log_every=args.log_step,
growth_rate=float("inf"),
)
else:
import bitsandbytes as bnb
opt_class = bnb.optim.AdamW8bit

View File

@ -452,14 +452,15 @@ def main(args):
unet = pipe.unet
del pipe
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", prediction_type="v_prediction")
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", prediction_type="v_prediction", trained_betas=trained_betas)
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
reference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)