Merge pull request #224 from a-l-e-x-d-s-9/main

EMA decay, min-SNR-gamma, settings loading every epoch, zero terminal SNR separation from ZFN.
Victor Hall 2023-09-12 15:36:27 -04:00 committed by GitHub
commit 49de395df1
5 changed files with 453 additions and 74 deletions


@ -149,11 +149,7 @@ Based on [Nicholas Guttenberg's blog post](https://www.crosslabs.org//blog/diffu
Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/zero_freq_test_biggs.webp
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergence over time.
## Zero terminal SNR
Set `zero_frequency_noise_ratio` to -1.
Very tentatively, I suggest closer to 0.10 for short term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications seem to suggest values like 0.10 can cause divergence over time.
## Keeping images together (custom batching)
@ -205,3 +201,37 @@ Clips the gradient normals to a maximum value. Default is None (no clipping).
Default is no gradient normal clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
## Zero Terminal SNR
**Parameter:** `--enable_zero_terminal_snr`
**Default:** `False`
Enables the zero terminal SNR noising beta schedule. Previously this was enabled by setting `zero_frequency_noise_ratio` to -1; it is now a separate flag. The beta rescaling it applies is sketched below.
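The idea is to rescale the noise schedule's betas so the signal-to-noise ratio at the final timestep is exactly zero. The sketch below follows the rescaling described in [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/abs/2305.08891); the function name `rescale_betas_zero_terminal_snr` is illustrative, and the implementation actually used by `train.py` is `enforce_zero_terminal_snr` in `utils/unet_utils.py`, which may differ in detail.

```python
import torch

def rescale_betas_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """Rescale a beta schedule so that the SNR at the final timestep is exactly zero."""
    alphas = 1.0 - betas
    alphas_cumprod = torch.cumprod(alphas, dim=0)
    alphas_bar_sqrt = alphas_cumprod.sqrt()

    # Remember the original endpoints of sqrt(alpha_bar)
    alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
    alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()

    # Shift so the last value is 0, then rescale so the first value is unchanged
    alphas_bar_sqrt -= alphas_bar_sqrt_T
    alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)

    # Convert sqrt(alpha_bar) back to betas
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = alphas_bar[1:] / alphas_bar[:-1]
    alphas = torch.cat([alphas_bar[0:1], alphas])
    return 1.0 - alphas
```

The training code then passes the rescaled betas as `trained_betas` when constructing both the DDIM inference scheduler and the DDPM noise scheduler, as shown in the `train.py` diff below.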
## Dynamic Configuration Loading
**Parameter:** `--load_settings_every_epoch`
**Default:** `False`
Most of the parameters in the `train.json` file CANNOT be modified during training. Enable this to have the `train.json` configuration file reloaded at the start of each epoch (a reload sketch follows the list). The following parameters can be changed and take effect at the start of the next epoch:
- `--save_every_n_epochs`
- `--save_ckpts_from_n_epochs`
- `--save_full_precision`
- `--save_optimizer`
- `--zero_frequency_noise_ratio`
- `--min_snr_gamma`
- `--clip_skip`
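As an illustration of what happens at each epoch boundary, the sketch below mirrors the `load_train_json_from_file` helper added to `train.py` in this PR; the `reload_settings` name is illustrative, and only the parameters listed above actually take effect mid-run.

```python
import argparse
import json

def reload_settings(args: argparse.Namespace, report_load: bool = False) -> None:
    """Re-read the JSON config named by --config and overlay it onto the argparse namespace."""
    try:
        if report_load:
            print(f"Loading training config from {args.config}.")
        with open(args.config, "rt") as f:
            args.__dict__.update(json.load(f))
    except Exception:
        # keep training with the current settings if the file is missing or malformed
        print(f"Error loading training config from {args.config}.")

# inside the training loop, once per epoch:
# if args.load_settings_every_epoch:
#     reload_settings(args)
```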
## Min-SNR-Gamma Parameter
**Parameter:** `--min_snr_gamma`
**Recommended Values:** 5, 1, 20
**Default:** `None`
Enables min-SNR-gamma loss weighting, sketched below. For an in-depth understanding, consult the [research paper](https://arxiv.org/abs/2303.09556).
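The sketch below shows the weighting this enables, roughly as applied in `train.py`'s `get_model_prediction_and_target`: the per-sample MSE is scaled by `min(SNR(t), gamma) / SNR(t)`. The helper name `min_snr_weighted_mse` is illustrative; the real code also guards against zero `alphas_cumprod` values (see `compute_snr` in the diff below).

```python
import torch
import torch.nn.functional as F

def min_snr_weighted_mse(model_pred, target, timesteps, noise_scheduler, snr_gamma: float):
    """MSE loss with per-sample weights min(SNR, gamma) / SNR (https://arxiv.org/abs/2303.09556)."""
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps].float()
    snr = alphas_cumprod / (1.0 - alphas_cumprod)  # SNR(t) = alpha_bar_t / (1 - alpha_bar_t)
    weights = torch.minimum(snr, torch.full_like(snr, snr_gamma)) / snr.clamp(min=1e-9)

    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    loss = loss.mean(dim=list(range(1, loss.dim())))  # one scalar per batch element
    return (loss * weights).mean()
```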
## EMA Decay Features
The Exponential Moving Average (EMA) model starts as a copy of the base model and is updated at a fixed step interval with a small contribution from the trained weights.
In this mode, the EMA model is saved alongside the regular training checkpoint. The regular checkpoint can be loaded with `--resume_ckpt`, and the EMA model with `--ema_decay_resume_model`. A minimal sketch of the update rule follows the parameter list.
**Parameters:**
- `--ema_decay_rate`: Sets the EMA decay rate, i.e. how strongly the EMA model retains its previous weights at each update; the trained model contributes a weight of `1 - ema_decay_rate`. Values should be close to, but not exceed, 1. Setting this parameter enables the EMA decay feature.
- `--ema_decay_target`: Sets the EMA decay target value in the (0,1) range. `ema_decay_rate` is then computed from the relation `ema_decay_rate ** (total_steps / ema_decay_interval) = ema_decay_target`. Setting this parameter overrides `ema_decay_rate` and enables the EMA decay feature.
- `--ema_decay_interval`: Sets the interval, in steps, between EMA updates. An update occurs whenever the global step is a multiple of `ema_decay_interval`.
- `--ema_decay_device`: Chooses where the EMA model is kept: `cpu` or `cuda`. On `cpu` each update takes around 4 seconds and uses approximately 3.2GB of RAM; `cuda` is much faster but reserves a similar amount of VRAM.
- `--ema_decay_sample_raw_training`: Enables samples from the trained model, as in conventional training. These are not generated by default when EMA decay is enabled.
- `--ema_decay_sample_ema_model`: Enables samples from the EMA model. The EMA model is used for sample generation by default when EMA decay is enabled, unless this is disabled.
- `--ema_decay_resume_model`: Specifies the EMA checkpoint to resume from; works like `--resume_ckpt` but loads the EMA model. `findlast` loads only the most recent EMA checkpoint, not the regular training checkpoint.
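Both relations above (the per-update blend and the decay-rate-from-target equation) reduce to a few lines. This is a minimal sketch, assuming both models share parameter names; the helper names `ema_update` and `decay_rate_from_target` are illustrative. The PR's `update_ema` in `train.py` additionally copies the trained model to the EMA device when `--ema_decay_device` differs from the training device.

```python
import torch

@torch.no_grad()
def ema_update(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float) -> None:
    """One EMA step: ema = decay * ema + (1 - decay) * model, parameter by parameter."""
    ema_params = dict(ema_model.named_parameters())
    for name, param in model.named_parameters():
        ema_param = ema_params[name]
        ema_param.mul_(decay).add_(
            param.detach().to(device=ema_param.device, dtype=ema_param.dtype),
            alpha=1.0 - decay,
        )

def decay_rate_from_target(decay_target: float, total_steps: int, decay_interval: int) -> float:
    """Solve decay_rate ** (total_steps / decay_interval) == decay_target for decay_rate."""
    number_of_updates = total_steps / decay_interval
    return decay_target ** (1.0 / number_of_updates)
```

For example, with the default `ema_decay_interval` of 500, a 10,000-step run targeting a decay of 0.5 gives `0.5 ** (500 / 10000) ≈ 0.966` as the effective `ema_decay_rate`.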


@ -40,5 +40,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_decay_interval": null,
"ema_decay_device": null,
"ema_decay_sample_raw_training": false,
"ema_decay_sample_ema_model": false,
"ema_decay_resume_model" : null
}

train.py

@ -61,6 +61,7 @@ from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
from utils.check_git import check_git
from optimizer.optimizers import EveryDreamOptimizer
from copy import deepcopy
if torch.cuda.is_available():
from utils.gpu import GPU
@ -186,7 +187,7 @@ def set_args_12gb(args):
logging.info(" - Overiding resolution to max 512")
args.resolution = 512
def find_last_checkpoint(logdir):
def find_last_checkpoint(logdir, is_ema=False):
"""
Finds the last checkpoint in the logdir, recursively
"""
@ -196,6 +197,12 @@ def find_last_checkpoint(logdir):
for root, dirs, files in os.walk(logdir):
for file in files:
if os.path.basename(file) == "model_index.json":
if is_ema and (not root.endswith("_ema")):
continue
elif (not is_ema) and root.endswith("_ema"):
continue
curr_date = os.path.getmtime(os.path.join(root,file))
if last_date is None or curr_date > last_date:
@ -228,6 +235,11 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
if args.lowvram:
set_args_12gb(args)
@ -356,6 +368,75 @@ def log_args(log_writer, args):
arglog += f"{arg}={value}, "
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
need_to_delete_original = False
if ema_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
params = dict(original_model_on_proper_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
for name in ema_params:
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
if need_to_delete_original:
del(original_model_on_proper_device)
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
minimal_value = 1e-9
alphas_cumprod = noise_scheduler.alphas_cumprod
# Use .any() to check if any elements in the tensor are zero
if (alphas_cumprod[:-1] == 0).any():
logging.warning(
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
)
alphas_cumprod[alphas_cumprod[:-1] == 0] = minimal_value
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR, first without epsilon
snr = (alpha / sigma) ** 2
# Check if the first element in SNR tensor is zero
if torch.any(snr == 0):
snr[snr == 0] = minimal_value
return snr
def load_train_json_from_file(args, report_load = False):
try:
if report_load:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error on loading training config from {args.config}.")
def main(args):
"""
@ -384,36 +465,37 @@ def main(args):
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
def release_memory(model_to_delete, original_device):
del model_to_delete
gc.collect()
if 'cuda' in original_device.type:
torch.cuda.empty_cache()
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
"""
Save the model to disk
"""
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
logging.info(f" * Saving diffusers model to {save_path}")
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
if save_ckpt:
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
nonlocal save_ckpt_dir
nonlocal save_full_precision
nonlocal yaml_name
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
@ -423,17 +505,78 @@ def main(args):
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
if args.ema_decay_rate != None:
pipeline_ema = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder_ema,
tokenizer=tokenizer,
unet=unet_ema,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
diffusers_model_path = save_path + "_ema"
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
pipeline_ema.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
unet=unet,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
diffusers_model_path = save_path
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
pipeline.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
ed_optimizer.save(save_path)
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
ema_decay_model_loaded_from_file = False
if use_ema_dacay_training:
ema_device = torch.device(args.ema_decay_device)
optimizer_state_path = None
try:
# check for a local file
@ -458,16 +601,51 @@ def main(args):
unet = pipe.unet
del pipe
# leave the inference scheduler alone
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
if use_ema_dacay_training and args.ema_decay_resume_model:
print(f"Loading EMA model: {args.ema_decay_resume_model}")
ema_decay_model_loaded_from_file=True
hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.ema_decay_resume_model)
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
else:
# try to download from HF using ema_decay_resume_model as a repo id
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
if ema_downloaded is None:
raise ValueError(
f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
text_encoder_ema = ema_pipe.text_encoder
unet_ema = ema_pipe.unet
del ema_pipe
# Make sure EMA model is on proper device, and memory released if moved
unet_ema_current_device = next(unet_ema.parameters()).device
if ema_device != unet_ema_current_device:
unet_ema_on_wrong_device = unet_ema
unet_ema = unet_ema.to(ema_device)
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
# Make sure EMA model is on proper device, and memory released if moved
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
if ema_device != text_encoder_ema_current_device:
text_encoder_ema_on_wrong_device = text_encoder_ema
text_encoder_ema = text_encoder_ema.to(ema_device)
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
if args.enable_zero_terminal_snr:
# Use zero terminal SNR
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
@ -504,6 +682,30 @@ def main(args):
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if use_ema_dacay_training:
if not ema_decay_model_loaded_from_file:
logging.info(f"EMA decay enabled, creating EMA model.")
with torch.no_grad():
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
else:
# Make sure correct types are used for models
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
@ -575,6 +777,20 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if use_ema_dacay_training:
if args.ema_decay_target != None:
# Value in range (0,1). decay_rate will be calculated from the equation: decay_rate^(total_steps/decay_interval) = decay_target. Using this parameter will enable the feature and override decay_rate.
total_number_of_steps: float = epoch_len * args.max_epochs
total_number_of_ema_update: float = total_number_of_steps / args.ema_decay_interval
args.ema_decay_rate = args.ema_decay_target ** (1 / total_number_of_ema_update)
logging.info(f"ema_decay_target is {args.ema_decay_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
logging.info(
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_decay_interval: {args.ema_decay_interval}, ema_decay_device: {args.ema_decay_device}.")
ed_optimizer = EveryDreamOptimizer(args,
optimizer_config,
text_encoder,
@ -590,7 +806,7 @@ def main(args):
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers,
use_penultimate_clip_layer=(args.clip_skip >= 2),
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
)
"""
@ -622,7 +838,9 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@ -670,7 +888,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@ -678,12 +896,13 @@ def main(args):
del pixel_values
latents = latents[0].sample() * 0.18215
if zero_frequency_noise_ratio > 0.0:
if zero_frequency_noise_ratio != None:
if zero_frequency_noise_ratio < 0:
zero_frequency_noise_ratio = 0
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@ -714,9 +933,35 @@ def main(args):
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
return model_pred, target
if return_loss:
if args.min_snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(timesteps, noise_scheduler)
mse_loss_weights = (
torch.stack(
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
mse_loss_weights[snr == 0] = 1.0
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
return model_pred, target, loss
else:
return model_pred, target
def generate_samples(global_step: int, batch):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
@ -725,20 +970,76 @@ def main(args):
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config,
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
del inference_pipe
gc.collect()
models_info = []
if (args.ema_decay_rate is None) or args.ema_decay_sample_raw_training:
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
models_info.append({"is_ema": True, "swap_required": ema_device != device})
for model_info in models_info:
extra_info: str = ""
if model_info["is_ema"]:
current_unet, current_text_encoder = unet_ema, text_encoder_ema
extra_info = "_ema"
else:
current_unet, current_text_encoder = unet, text_encoder
torch.cuda.empty_cache()
if model_info["swap_required"]:
with torch.no_grad():
unet_unloaded = unet.to(ema_device)
del unet
text_encoder_unloaded = text_encoder.to(ema_device)
del text_encoder
current_unet = unet_ema.to(device)
del unet_ema
current_text_encoder = text_encoder_ema.to(device)
del text_encoder_ema
gc.collect()
torch.cuda.empty_cache()
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
text_encoder=current_text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
# Cleanup
del inference_pipe
if model_info["swap_required"]:
with torch.no_grad():
unet = unet_unloaded.to(device)
del unet_unloaded
text_encoder = text_encoder_unloaded.to(device)
del text_encoder_unloaded
unet_ema = current_unet.to(ema_device)
del current_unet
text_encoder_ema = current_text_encoder.to(ema_device)
del current_text_encoder
gc.collect()
torch.cuda.empty_cache()
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation(global_step=0,
@ -755,7 +1056,7 @@ def main(args):
else:
logging.info("No plugins specified")
plugins = []
from plugins.plugins import PluginRunner
plugin_runner = PluginRunner(plugins=plugins)
@ -763,6 +1064,11 @@ def main(args):
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
for epoch in range(args.max_epochs):
if args.load_settings_every_epoch:
load_train_json_from_file(args)
plugin_runner.run_on_epoch_start(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
@ -790,9 +1096,7 @@ def main(args):
log_folder=log_folder,
batch=batch)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
del target, model_pred
@ -802,6 +1106,21 @@ def main(args):
ed_optimizer.step(loss, step, global_step)
if args.ema_decay_rate != None:
if ((global_step + 1) % args.ema_decay_interval) == 0:
# debug_start_time = time.time() # Measure time
if args.disable_unet_training != True:
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
if args.disable_textenc_training != True:
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
# debug_end_time = time.time() # Measure time
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
@ -818,7 +1137,7 @@ def main(args):
lr_unet = ed_optimizer.get_unet_lr()
lr_textenc = ed_optimizer.get_textenc_lr()
loss_log_step = []
log_writer.add_scalar(tag="hyperparameter/lr unet", scalar_value=lr_unet, global_step=global_step)
log_writer.add_scalar(tag="hyperparameter/lr text encoder", scalar_value=lr_textenc, global_step=global_step)
log_writer.add_scalar(tag="loss/log_step", scalar_value=loss_step, global_step=global_step)
@ -846,12 +1165,16 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
plugin_runner.run_on_step_end(epoch=epoch,
global_step=global_step,
@ -884,13 +1207,14 @@ def main(args):
log_folder=log_folder,
data_root=args.data_root)
gc.collect()
gc.collect()
# end of epoch
# end of training
epoch = args.max_epochs
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -900,7 +1224,8 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
raise ex
@ -916,14 +1241,7 @@ if __name__ == "__main__":
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()
if args.config is not None:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
args.__dict__.update(json.load(f))
if len(argv) > 0:
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
else:
print("No config file specified, using command line args")
load_train_json_from_file(args, report_load=True)
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@ -972,9 +1290,20 @@ if __name__ == "__main__":
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide ema_decay_rate.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)
main(args)


@ -39,5 +39,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_decay_interval": null,
"ema_decay_device": null,
"ema_decay_sample_raw_training": false,
"ema_decay_sample_ema_model": false,
"ema_decay_resume_model" : null
}


@ -181,7 +181,7 @@ class SampleGenerator:
self.sample_requests = self._make_random_caption_sample_requests()
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
"""
generates samples at different cfg scales and saves them to disk
"""
@ -269,15 +269,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1
del result