Merge pull request #224 from a-l-e-x-d-s-9/main

EMA decay, min-SNR-gamma, settings loading every epoch, zero terminal SNR separation from ZFN.
Victor Hall 2023-09-12 15:36:27 -04:00 committed by GitHub
commit 49de395df1
5 changed files with 453 additions and 74 deletions


@ -151,10 +151,6 @@ Test results: https://huggingface.co/panopstor/ff7r-stable-diffusion/blob/main/z
Very tentatively, I suggest closer to 0.10 for short-term training, and lower values of around 0.02 to 0.03 for longer runs (50k+ steps). Early indications suggest values like 0.10 can cause divergence over time.
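For reference, here is a condensed sketch of how the ratio is applied when constructing training noise; it mirrors the noise construction in `get_model_prediction_and_target` in `train.py` further down this diff (the helper name here is for illustration only).

```python
import torch

def offset_noise(latents: torch.Tensor, zero_frequency_noise_ratio: float) -> torch.Tensor:
    """Add a per-channel constant ("zero frequency") offset to standard Gaussian noise."""
    # One random offset per (sample, channel), broadcast over the spatial dims.
    zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(
        latents.shape[0], latents.shape[1], 1, 1, device=latents.device
    )
    return torch.randn_like(latents) + zero_frequency_noise
```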
## Zero terminal SNR
Set `zero_frequency_noise_ratio` to -1.
## Keeping images together (custom batching)
If you have a subset of your dataset that expresses the same style or concept, training quality may be improved by putting all of these images through the trainer together in the same batch or batches, instead of the default behaviour (which is to shuffle them randomly throughout the entire dataset).
@ -205,3 +201,37 @@ Clips the gradient normals to a maximum value. Default is None (no clipping).
Default is no gradient norm clipping. There are also other ways to deal with gradient explosion, such as increasing optimizer epsilon.
## Zero Terminal SNR
**Parameter:** `--enable_zero_terminal_snr`
**Default:** `False`
Set this flag to train with a zero terminal SNR noise schedule: the beta schedule is rescaled so that the final timestep carries no signal (pure noise).
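As a rough illustration (not necessarily the trainer's exact implementation, which lives in `utils/unet_utils.enforce_zero_terminal_snr`), a beta schedule can be rescaled to zero terminal SNR like this:

```python
import torch

def rescale_betas_zero_terminal_snr(betas: torch.Tensor) -> torch.Tensor:
    """Rescale a beta schedule so the last timestep has zero SNR (pure noise)."""
    alphas = 1.0 - betas
    alphas_bar_sqrt = torch.cumprod(alphas, dim=0).sqrt()

    # Shift/scale sqrt(alpha_bar) so the final value is exactly 0 while the first is preserved.
    a0, aT = alphas_bar_sqrt[0].clone(), alphas_bar_sqrt[-1].clone()
    alphas_bar_sqrt = (alphas_bar_sqrt - aT) * a0 / (a0 - aT)

    # Convert back to per-step alphas and betas.
    alphas_bar = alphas_bar_sqrt ** 2
    alphas = torch.cat([alphas_bar[:1], alphas_bar[1:] / alphas_bar[:-1]])
    return 1.0 - alphas
```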
## Dynamic Configuration Loading
**Parameter:** `--load_settings_every_epoch`
**Default:** `False`
Most of the parameters in the `train.json` file CANNOT be modified during training. Enable this flag to have the `train.json` configuration file reloaded at the start of each epoch (a minimal sketch of the reload follows the list below). The following parameters can be changed, and the new values take effect at the start of the next epoch:
- `--save_every_n_epochs`
- `--save_ckpts_from_n_epochs`
- `--save_full_precision`
- `--save_optimizer`
- `--zero_frequency_noise_ratio`
- `--min_snr_gamma`
- `--clip_skip`
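A minimal sketch of per-epoch reloading, assuming an argparse-style `args` namespace; it mirrors the `load_train_json_from_file` helper and the epoch-start hook added to `train.py` further down this diff.

```python
import json

def reload_settings(args, config_path: str) -> None:
    """Re-read the JSON config and overwrite matching attributes on the args namespace."""
    try:
        with open(config_path, "rt") as f:
            args.__dict__.update(json.load(f))
    except Exception as e:
        print(f"Could not reload {config_path}: {e}")

# Inside the training loop (outline):
# for epoch in range(args.max_epochs):
#     if args.load_settings_every_epoch:
#         reload_settings(args, args.config)
#     ...
```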
## Min-SNR-Gamma Parameter
**Parameter:** `--min_snr_gamma`
**Recommended Values:** 5, 1, 20
**Default:** `None`
Enables min-SNR-gamma loss weighting, which weights the per-timestep loss by `min(SNR, gamma) / SNR`, as sketched below. For an in-depth understanding, consult this [research paper](https://arxiv.org/abs/2303.09556).
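A condensed sketch of the weighting, assuming a per-sample `snr` tensor (in `train.py` below, this is produced by `compute_snr` from the scheduler's `alphas_cumprod`):

```python
import torch
import torch.nn.functional as F

def min_snr_weighted_mse(model_pred: torch.Tensor, target: torch.Tensor,
                         snr: torch.Tensor, gamma: float) -> torch.Tensor:
    """Per-sample MSE weighted by min(SNR, gamma) / SNR, per arXiv:2303.09556."""
    weights = torch.minimum(snr, torch.full_like(snr, gamma)) / snr.clamp(min=1e-9)
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
    # Reduce over all non-batch dims, then apply the per-sample weight.
    loss = loss.mean(dim=list(range(1, loss.dim()))) * weights
    return loss.mean()
```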
## EMA Decay Features
The Exponential Moving Average (EMA) model starts as a copy of the base model and is updated at a fixed interval of steps with a small contribution from the training weights.
In this mode, the EMA model is saved alongside the regular training checkpoint. The normal training checkpoint can be loaded with `--resume_ckpt`, and the EMA model can be loaded with `--ema_decay_resume_model`.
**Parameters:**
- `--ema_decay_rate`: The EMA decay rate; it determines how much the EMA model is updated from training at each update. Values should be close to, but not exceed, 1. Setting this parameter enables the EMA decay feature.
- `--ema_decay_target`: The EMA decay target value, in the range (0,1). The `ema_decay_rate` is computed from the relation `decay_rate ^ (total_steps / decay_interval) = decay_target` (see the sketch after this list). Setting this parameter overrides `ema_decay_rate` and enables the EMA decay feature.
- `--ema_decay_interval`: The interval, in steps, between EMA decay updates. An update occurs whenever the global step count is a multiple of `decay_interval`.
- `--ema_decay_device`: Choose between `cpu` and `cuda` for the EMA model. `cpu` takes around 4 seconds per update and uses approximately 3.2GB of RAM, while `cuda` is much faster but reserves a similar amount of VRAM.
- `--ema_decay_sample_raw_training`: Enable to generate samples from the regularly trained model, as in conventional training. With EMA decay enabled, these samples are not generated by default.
- `--ema_decay_sample_ema_model`: Enable to generate samples from the EMA model. With EMA decay enabled, the EMA model is used for sample generation by default, unless disabled.
- `--ema_decay_resume_model`: The EMA checkpoint to resume from; works like `--resume_ckpt` but loads the EMA model. Passing `findlast` loads only the latest EMA checkpoint, not the regular training checkpoint.
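Below is a minimal sketch of the update rule and of deriving the decay rate from a target (names mirror the parameters above; the actual `update_ema` in `train.py` further down also handles moving the model between devices):

```python
import torch

def ema_update(model: torch.nn.Module, ema_model: torch.nn.Module, decay: float) -> None:
    """Blend training weights into the EMA copy: ema = decay * ema + (1 - decay) * model."""
    with torch.no_grad():
        ema_params = dict(ema_model.named_parameters())
        for name, param in model.named_parameters():
            ema_params[name].data.mul_(decay).add_(
                param.data.to(ema_params[name].device), alpha=1.0 - decay
            )

def decay_rate_from_target(decay_target: float, total_steps: int, decay_interval: int) -> float:
    """Solve decay_rate ** (total_steps / decay_interval) = decay_target for decay_rate."""
    number_of_ema_updates = total_steps / decay_interval
    return decay_target ** (1.0 / number_of_ema_updates)
```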


@ -40,5 +40,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_decay_interval": null,
"ema_decay_device": null,
"ema_decay_sample_raw_training": false,
"ema_decay_sample_ema_model": false,
"ema_decay_resume_model" : null
}

train.py (425 changes)

@ -61,6 +61,7 @@ from utils.convert_diff_to_ckpt import convert as converter
from utils.isolate_rng import isolate_rng
from utils.check_git import check_git
from optimizer.optimizers import EveryDreamOptimizer
from copy import deepcopy
if torch.cuda.is_available():
from utils.gpu import GPU
@ -186,7 +187,7 @@ def set_args_12gb(args):
logging.info(" - Overiding resolution to max 512")
args.resolution = 512
def find_last_checkpoint(logdir):
def find_last_checkpoint(logdir, is_ema=False):
"""
Finds the last checkpoint in the logdir, recursively
"""
@ -196,6 +197,12 @@ def find_last_checkpoint(logdir):
for root, dirs, files in os.walk(logdir):
for file in files:
if os.path.basename(file) == "model_index.json":
if is_ema and (not root.endswith("_ema")):
continue
elif (not is_ema) and root.endswith("_ema"):
continue
curr_date = os.path.getmtime(os.path.join(root,file))
if last_date is None or curr_date > last_date:
@ -228,6 +235,11 @@ def setup_args(args):
# find the last checkpoint in the logdir
args.resume_ckpt = find_last_checkpoint(args.logdir)
if (args.ema_decay_resume_model != None) and (args.ema_decay_resume_model == "findlast"):
logging.info(f"{Fore.LIGHTCYAN_EX} Finding last EMA decay checkpoint in logdir: {args.logdir}{Style.RESET_ALL}")
args.ema_decay_resume_model = find_last_checkpoint(args.logdir, is_ema=True)
if args.lowvram:
set_args_12gb(args)
@ -356,6 +368,75 @@ def log_args(log_writer, args):
arglog += f"{arg}={value}, "
log_writer.add_text("config", arglog)
def update_ema(model, ema_model, decay, default_device, ema_device):
with torch.no_grad():
original_model_on_proper_device = model
need_to_delete_original = False
if ema_device != default_device:
original_model_on_other_device = deepcopy(model)
original_model_on_proper_device = original_model_on_other_device.to(ema_device, dtype=model.dtype)
del original_model_on_other_device
need_to_delete_original = True
params = dict(original_model_on_proper_device.named_parameters())
ema_params = dict(ema_model.named_parameters())
for name in ema_params:
#ema_params[name].data.mul_(decay).add_(params[name].data, alpha=1 - decay)
ema_params[name].data = ema_params[name] * decay + params[name].data * (1.0 - decay)
if need_to_delete_original:
del(original_model_on_proper_device)
def compute_snr(timesteps, noise_scheduler):
"""
Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
"""
minimal_value = 1e-9
alphas_cumprod = noise_scheduler.alphas_cumprod
# Use .any() to check if any elements in the tensor are zero
if (alphas_cumprod[:-1] == 0).any():
logging.warning(
f"Alphas cumprod has zero elements! Resetting to {minimal_value}.."
)
alphas_cumprod[alphas_cumprod[:-1] == 0] = minimal_value
sqrt_alphas_cumprod = alphas_cumprod**0.5
sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
# Expand the tensors.
# Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
timesteps
].float()
while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
device=timesteps.device
)[timesteps].float()
while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
# Compute SNR, first without epsilon
snr = (alpha / sigma) ** 2
# Check if the first element in SNR tensor is zero
if torch.any(snr == 0):
snr[snr == 0] = minimal_value
return snr
def load_train_json_from_file(args, report_load = False):
try:
if report_load:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
read_json = json.load(f)
args.__dict__.update(read_json)
except Exception as config_read:
print(f"Error on loading training config from {args.config}.")
def main(args):
"""
@ -384,22 +465,86 @@ def main(args):
device = 'cpu'
gpu = None
log_folder = os.path.join(args.logdir, f"{args.project_name}_{log_time}")
if not os.path.exists(log_folder):
os.makedirs(log_folder)
def release_memory(model_to_delete, original_device):
del model_to_delete
gc.collect()
if 'cuda' in original_device.type:
torch.cuda.empty_cache()
@torch.no_grad()
def __save_model(save_path, unet, text_encoder, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
def __save_model(save_path, tokenizer, scheduler, vae, ed_optimizer, save_ckpt_dir, yaml_name,
save_full_precision=False, save_optimizer_flag=False, save_ckpt=True):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
"""
Save the model to disk
"""
def save_ckpt_file(diffusers_model_path, sd_ckpt_path):
nonlocal save_ckpt_dir
nonlocal save_full_precision
nonlocal yaml_name
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
half = not save_full_precision
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=diffusers_model_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(diffusers_model_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
global global_step
if global_step is None or global_step == 0:
logging.warning(" No model to save, something likely blew up on startup, not saving")
return
logging.info(f" * Saving diffusers model to {save_path}")
if args.ema_decay_rate != None:
pipeline_ema = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder_ema,
tokenizer=tokenizer,
unet=unet_ema,
scheduler=scheduler,
safety_checker=None, # save vram
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
diffusers_model_path = save_path + "_ema"
logging.info(f" * Saving diffusers EMA model to {diffusers_model_path}")
pipeline_ema.save_pretrained(diffusers_model_path)
if save_ckpt:
sd_ckpt_path_ema = f"{os.path.basename(save_path)}_ema.ckpt"
save_ckpt_file(diffusers_model_path, sd_ckpt_path_ema)
pipeline = StableDiffusionPipeline(
vae=vae,
text_encoder=text_encoder,
@ -410,30 +555,28 @@ def main(args):
requires_safety_checker=None, # avoid nag
feature_extractor=None, # must be None if no safety checker
)
pipeline.save_pretrained(save_path)
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
diffusers_model_path = save_path
logging.info(f" * Saving diffusers model to {diffusers_model_path}")
pipeline.save_pretrained(diffusers_model_path)
if save_ckpt:
if save_ckpt_dir is not None:
sd_ckpt_full = os.path.join(save_ckpt_dir, sd_ckpt_path)
else:
sd_ckpt_full = os.path.join(os.curdir, sd_ckpt_path)
save_ckpt_dir = os.curdir
sd_ckpt_path = f"{os.path.basename(save_path)}.ckpt"
half = not save_full_precision
save_ckpt_file(diffusers_model_path, sd_ckpt_path)
logging.info(f" * Saving SD model to {sd_ckpt_full}")
converter(model_path=save_path, checkpoint_path=sd_ckpt_full, half=half)
if yaml_name and yaml_name != "v1-inference.yaml":
yaml_save_path = f"{os.path.join(save_ckpt_dir, os.path.basename(save_path))}.yaml"
logging.info(f" * Saving yaml to {yaml_save_path}")
shutil.copyfile(yaml_name, yaml_save_path)
if save_optimizer_flag:
logging.info(f" Saving optimizer state to {save_path}")
ed_optimizer.save(save_path)
use_ema_dacay_training = (args.ema_decay_rate != None) or (args.ema_decay_target != None)
ema_decay_model_loaded_from_file = False
if use_ema_dacay_training:
ema_device = torch.device(args.ema_decay_device)
optimizer_state_path = None
try:
# check for a local file
@ -458,16 +601,51 @@ def main(args):
unet = pipe.unet
del pipe
# leave the inference scheduler alone
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
if use_ema_dacay_training and args.ema_decay_resume_model:
print(f"Loading EMA model: {args.ema_decay_resume_model}")
ema_decay_model_loaded_from_file=True
hf_cache_path = get_hf_ckpt_cache_path(args.ema_decay_resume_model)
if args.zero_frequency_noise_ratio == -1.0:
# use zero terminal SNR, currently backdoor way to enable it by setting ZFN to -1, still in testing
if os.path.exists(hf_cache_path) or os.path.exists(args.ema_decay_resume_model):
ema_model_root_folder, ema_is_sd1attn, ema_yaml = convert_to_hf(args.resume_ckpt)
text_encoder_ema = CLIPTextModel.from_pretrained(ema_model_root_folder, subfolder="text_encoder")
unet_ema = UNet2DConditionModel.from_pretrained(ema_model_root_folder, subfolder="unet")
else:
# try to download from HF using ema_decay_resume_model as a repo id
ema_downloaded = try_download_model_from_hf(repo_id=args.ema_decay_resume_model)
if ema_downloaded is None:
raise ValueError(
f"No local file/folder for ema_decay_resume_model {args.ema_decay_resume_model}, and no matching huggingface.co repo could be downloaded")
ema_pipe, ema_model_root_folder, ema_is_sd1attn, ema_yaml = ema_downloaded
text_encoder_ema = ema_pipe.text_encoder
unet_ema = ema_pipe.unet
del ema_pipe
# Make sure EMA model is on proper device, and memory released if moved
unet_ema_current_device = next(unet_ema.parameters()).device
if ema_device != unet_ema_current_device:
unet_ema_on_wrong_device = unet_ema
unet_ema = unet_ema.to(ema_device)
release_memory(unet_ema_on_wrong_device, unet_ema_current_device)
# Make sure EMA model is on proper device, and memory released if moved
text_encoder_ema_current_device = next(text_encoder_ema.parameters()).device
if ema_device != text_encoder_ema_current_device:
text_encoder_ema_on_wrong_device = text_encoder_ema
text_encoder_ema = text_encoder_ema.to(ema_device)
release_memory(text_encoder_ema_on_wrong_device, text_encoder_ema_current_device)
if args.enable_zero_terminal_snr:
# Use zero terminal SNR
from utils.unet_utils import enforce_zero_terminal_snr
temp_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
trained_betas = enforce_zero_terminal_snr(temp_scheduler.betas).numpy().tolist()
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler", trained_betas=trained_betas)
else:
inference_scheduler = DDIMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
noise_scheduler = DDPMScheduler.from_pretrained(model_root_folder, subfolder="scheduler")
tokenizer = CLIPTokenizer.from_pretrained(model_root_folder, subfolder="tokenizer", use_fast=False)
@ -504,6 +682,30 @@ def main(args):
else:
text_encoder = text_encoder.to(device, dtype=torch.float32)
if use_ema_dacay_training:
if not ema_decay_model_loaded_from_file:
logging.info(f"EMA decay enabled, creating EMA model.")
with torch.no_grad():
if args.ema_decay_device == device:
unet_ema = deepcopy(unet)
text_encoder_ema = deepcopy(text_encoder)
else:
unet_ema_first = deepcopy(unet)
text_encoder_ema_first = deepcopy(text_encoder)
unet_ema = unet_ema_first.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema_first.to(ema_device, dtype=text_encoder.dtype)
del unet_ema_first
del text_encoder_ema_first
else:
# Make sure correct types are used for models
unet_ema = unet_ema.to(ema_device, dtype=unet.dtype)
text_encoder_ema = text_encoder_ema.to(ema_device, dtype=text_encoder.dtype)
try:
#unet = torch.compile(unet)
#text_encoder = torch.compile(text_encoder)
@ -575,6 +777,20 @@ def main(args):
epoch_len = math.ceil(len(train_batch) / args.batch_size)
if use_ema_dacay_training:
if args.ema_decay_target != None:
# Value in range (0,1). decay_rate will be calculated from the equation: decay_rate^(total_steps/decay_interval) = decay_target. Using this parameter enables the feature and overrides ema_decay_rate.
total_number_of_steps: float = epoch_len * args.max_epochs
total_number_of_ema_update: float = total_number_of_steps / args.ema_decay_interval
args.ema_decay_rate = args.ema_decay_target ** (1 / total_number_of_ema_update)
logging.info(f"ema_decay_target is {args.ema_decay_target}, calculated ema_decay_rate will be: {args.ema_decay_rate}.")
logging.info(
f"EMA decay enabled, with ema_decay_rate {args.ema_decay_rate}, ema_decay_interval: {args.ema_decay_interval}, ema_decay_device: {args.ema_decay_device}.")
ed_optimizer = EveryDreamOptimizer(args,
optimizer_config,
text_encoder,
@ -590,7 +806,7 @@ def main(args):
default_sample_steps=args.sample_steps,
use_xformers=is_xformers_available() and not args.disable_xformers,
use_penultimate_clip_layer=(args.clip_skip >= 2),
guidance_rescale = 0.7 if args.zero_frequency_noise_ratio == -1 else 0
guidance_rescale=0.7 if args.enable_zero_terminal_snr else 0
)
"""
@ -622,7 +838,9 @@ def main(args):
logging.error(f"{Fore.LIGHTRED_EX} CTRL-C received, attempting to save model to {interrupted_checkpoint_path}{Style.RESET_ALL}")
logging.error(f"{Fore.LIGHTRED_EX} ************************************************************************{Style.RESET_ALL}")
time.sleep(2) # give opportunity to ctrl-C again to cancel save
__save_model(interrupted_checkpoint_path, unet, text_encoder, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(interrupted_checkpoint_path, tokenizer, noise_scheduler, vae,
ed_optimizer, args.save_ckpt_dir, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
exit(_SIGTERM_EXIT_CODE)
else:
# non-main threads (i.e. dataloader workers) should exit cleanly
@ -670,7 +888,7 @@ def main(args):
assert len(train_batch) > 0, "train_batch is empty, check that your data_root is correct"
# actual prediction function - shared between train and validate
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0):
def get_model_prediction_and_target(image, tokens, zero_frequency_noise_ratio=0.0, return_loss=False):
with torch.no_grad():
with autocast(enabled=args.amp):
pixel_values = image.to(memory_format=torch.contiguous_format).to(unet.device)
@ -678,12 +896,13 @@ def main(args):
del pixel_values
latents = latents[0].sample() * 0.18215
if zero_frequency_noise_ratio > 0.0:
if zero_frequency_noise_ratio != None:
if zero_frequency_noise_ratio < 0:
zero_frequency_noise_ratio = 0
# see https://www.crosslabs.org//blog/diffusion-with-offset-noise
zero_frequency_noise = zero_frequency_noise_ratio * torch.randn(latents.shape[0], latents.shape[1], 1, 1, device=latents.device)
noise = torch.randn_like(latents) + zero_frequency_noise
else:
noise = torch.randn_like(latents)
bsz = latents.shape[0]
@ -714,9 +933,35 @@ def main(args):
#print(f"types: {type(noisy_latents)} {type(timesteps)} {type(encoder_hidden_states)}")
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
if return_loss:
if args.min_snr_gamma is None:
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
else:
snr = compute_snr(timesteps, noise_scheduler)
mse_loss_weights = (
torch.stack(
[snr, args.min_snr_gamma * torch.ones_like(timesteps)], dim=1
).min(dim=1)[0]
/ snr
)
mse_loss_weights[snr == 0] = 1.0
loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
loss = loss.mean(dim=list(range(1, len(loss.shape)))) * mse_loss_weights
loss = loss.mean()
return model_pred, target, loss
else:
return model_pred, target
def generate_samples(global_step: int, batch):
nonlocal unet
nonlocal text_encoder
nonlocal unet_ema
nonlocal text_encoder_ema
with isolate_rng():
prev_sample_steps = sample_generator.sample_steps
sample_generator.reload_config()
@ -725,20 +970,76 @@ def main(args):
print(f" * SampleGenerator config changed, now generating images samples every " +
f"{sample_generator.sample_steps} training steps (next={next_sample_step})")
sample_generator.update_random_captions(batch["captions"])
inference_pipe = sample_generator.create_inference_pipe(unet=unet,
text_encoder=text_encoder,
models_info = []
if (args.ema_decay_rate is None) or args.ema_decay_sample_raw_training:
models_info.append({"is_ema": False, "swap_required": False})
if (args.ema_decay_rate is not None) and args.ema_decay_sample_ema_model:
models_info.append({"is_ema": True, "swap_required": ema_device != device})
for model_info in models_info:
extra_info: str = ""
if model_info["is_ema"]:
current_unet, current_text_encoder = unet_ema, text_encoder_ema
extra_info = "_ema"
else:
current_unet, current_text_encoder = unet, text_encoder
torch.cuda.empty_cache()
if model_info["swap_required"]:
with torch.no_grad():
unet_unloaded = unet.to(ema_device)
del unet
text_encoder_unloaded = text_encoder.to(ema_device)
del text_encoder
current_unet = unet_ema.to(device)
del unet_ema
current_text_encoder = text_encoder_ema.to(device)
del text_encoder_ema
gc.collect()
torch.cuda.empty_cache()
inference_pipe = sample_generator.create_inference_pipe(unet=current_unet,
text_encoder=current_text_encoder,
tokenizer=tokenizer,
vae=vae,
diffusers_scheduler_config=inference_scheduler.config,
diffusers_scheduler_config=inference_scheduler.config
).to(device)
sample_generator.generate_samples(inference_pipe, global_step)
sample_generator.generate_samples(inference_pipe, global_step, extra_info=extra_info)
# Cleanup
del inference_pipe
if model_info["swap_required"]:
with torch.no_grad():
unet = unet_unloaded.to(device)
del unet_unloaded
text_encoder = text_encoder_unloaded.to(device)
del text_encoder_unloaded
unet_ema = current_unet.to(ema_device)
del current_unet
text_encoder_ema = current_text_encoder.to(ema_device)
del current_text_encoder
gc.collect()
torch.cuda.empty_cache()
def make_save_path(epoch, global_step, prepend=""):
return os.path.join(f"{log_folder}/ckpts/{prepend}{args.project_name}-ep{epoch:02}-gs{global_step:05}")
# Pre-train validation to establish a starting point on the loss graph
if validator:
validator.do_validation(global_step=0,
@ -763,6 +1064,11 @@ def main(args):
write_batch_schedule(args, log_folder, train_batch, epoch = 0)
for epoch in range(args.max_epochs):
if args.load_settings_every_epoch:
load_train_json_from_file(args)
plugin_runner.run_on_epoch_start(epoch=epoch,
global_step=global_step,
project_name=args.project_name,
@ -790,9 +1096,7 @@ def main(args):
log_folder=log_folder,
batch=batch)
model_pred, target = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio)
loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
model_pred, target, loss = get_model_prediction_and_target(batch["image"], batch["tokens"], args.zero_frequency_noise_ratio, return_loss=True)
del target, model_pred
@ -802,6 +1106,21 @@ def main(args):
ed_optimizer.step(loss, step, global_step)
if args.ema_decay_rate != None:
if ((global_step + 1) % args.ema_decay_interval) == 0:
# debug_start_time = time.time() # Measure time
if args.disable_unet_training != True:
update_ema(unet, unet_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
if args.disable_textenc_training != True:
update_ema(text_encoder, text_encoder_ema, args.ema_decay_rate, default_device=device, ema_device=ema_device)
# debug_end_time = time.time() # Measure time
# debug_elapsed_time = debug_end_time - debug_start_time # Measure time
# print(f"Command update_EMA unet and TE took {debug_elapsed_time:.3f} seconds.") # Measure time
loss_step = loss.detach().item()
steps_pbar.set_postfix({"loss/step": loss_step}, {"gs": global_step})
@ -846,12 +1165,16 @@ def main(args):
last_epoch_saved_time = time.time()
logging.info(f"Saving model, {args.ckpt_every_n_minutes} mins at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
if epoch > 0 and epoch % args.save_every_n_epochs == 0 and step == 0 and epoch < args.max_epochs - 1 and epoch >= args.save_ckpts_from_n_epochs:
logging.info(f" Saving model, {args.save_every_n_epochs} epochs at step {global_step}")
save_path = make_save_path(epoch, global_step)
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer,
args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer,
save_ckpt=not args.no_save_ckpt)
plugin_runner.run_on_step_end(epoch=epoch,
global_step=global_step,
@ -890,7 +1213,8 @@ def main(args):
# end of training
epoch = args.max_epochs
save_path = make_save_path(epoch, global_step, prepend=("" if args.no_prepend_last else "last-"))
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
total_elapsed_time = time.time() - training_start_time
logging.info(f"{Fore.CYAN}Training complete{Style.RESET_ALL}")
@ -900,7 +1224,8 @@ def main(args):
except Exception as ex:
logging.error(f"{Fore.LIGHTYELLOW_EX}Something went wrong, attempting to save model{Style.RESET_ALL}")
save_path = make_save_path(epoch, global_step, prepend="errored-")
__save_model(save_path, unet, text_encoder, tokenizer, inference_scheduler, vae, ed_optimizer, args.save_ckpt_dir, yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
__save_model(save_path, tokenizer, noise_scheduler, vae, ed_optimizer, args.save_ckpt_dir,
yaml, args.save_full_precision, args.save_optimizer, save_ckpt=not args.no_save_ckpt)
logging.info(f"{Fore.LIGHTYELLOW_EX}Model saved, re-raising exception and exiting. Exception was:{Style.RESET_ALL}{Fore.LIGHTRED_EX} {ex} {Style.RESET_ALL}")
raise ex
@ -916,14 +1241,7 @@ if __name__ == "__main__":
argparser.add_argument("--config", type=str, required=False, default=None, help="JSON config file to load options from")
args, argv = argparser.parse_known_args()
if args.config is not None:
print(f"Loading training config from {args.config}.")
with open(args.config, 'rt') as f:
args.__dict__.update(json.load(f))
if len(argv) > 0:
print(f"Config .json loaded but there are additional CLI arguments -- these will override values in {args.config}.")
else:
print("No config file specified, using command line args")
load_train_json_from_file(args, report_load=True)
argparser = argparse.ArgumentParser(description="EveryDream2 Training options")
argparser.add_argument("--amp", action="store_true", default=True, help="deprecated, use --disable_amp if you wish to disable AMP")
@ -972,7 +1290,18 @@ if __name__ == "__main__":
argparser.add_argument("--write_schedule", action="store_true", default=False, help="write schedule of images and their batches to file (def: False)")
argparser.add_argument("--rated_dataset", action="store_true", default=False, help="enable rated image set training, to less often train on lower rated images through the epochs")
argparser.add_argument("--rated_dataset_target_dropout_percent", type=int, default=50, help="how many images (in percent) should be included in the last epoch (Default 50)")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15, set to -1 to use zero terminal SNR noising beta schedule instead")
argparser.add_argument("--zero_frequency_noise_ratio", type=float, default=0.02, help="adds zero frequency noise, for improving contrast (def: 0.0) use 0.0 to 0.15")
argparser.add_argument("--enable_zero_terminal_snr", action="store_true", default=None, help="Use zero terminal SNR noising beta schedule")
argparser.add_argument("--load_settings_every_epoch", action="store_true", default=None, help="Will load 'train.json' at start of every epoch. Disabled by default and enabled when used.")
argparser.add_argument("--min_snr_gamma", type=int, default=None, help="min-SNR-gamma parameteris the loss function into individual tasks. Recommended values: 5, 1, 20. Disabled by default and enabled when used. More info: https://arxiv.org/abs/2303.09556")
argparser.add_argument("--ema_decay_rate", type=float, default=None, help="EMA decay rate. EMA model will be updated with (1 - ema_decay_rate) from training, and the ema_decay_rate from previous EMA, every interval. Values less than 1 and not so far from 1. Using this parameter will enable the feature.")
argparser.add_argument("--ema_decay_target", type=float, default=None, help="EMA decay target value in range (0,1). ema_decay_rate will be calculated from equation: decay_rate^(total_steps/decay_interval)=decay_target. Using this parameter will enable the feature and overide ema_decay_rate.")
argparser.add_argument("--ema_decay_interval", type=int, default=500, help="How many steps between every EMA decay update. EMA model will be update on every global_steps modulo decay_interval.")
argparser.add_argument("--ema_decay_device", type=str, default='cpu', help="EMA decay device values: cpu, cuda. Using 'cpu' is taking around 4 seconds per update vs fraction of a second on 'cuda'. Using 'cuda' will reserve around 3.2GB VRAM for a model, with 'cpu' RAM will be used.")
argparser.add_argument("--ema_decay_sample_raw_training", action="store_true", default=False, help="Will show samples from trained model, just like regular training. Can be used with: --ema_decay_sample_ema_model")
argparser.add_argument("--ema_decay_sample_ema_model", action="store_true", default=False, help="Will show samples from EMA model. Can be used with: --ema_decay_sample_raw_training")
argparser.add_argument("--ema_decay_resume_model", type=str, default=None, help="The EMA decay checkpoint to resume from for EMA decay, either a local .ckpt file, a converted Diffusers format folder, or a Huggingface.co repo id such as stabilityai/stable-diffusion-2-1-ema-decay")
# load CLI args to overwrite existing config args
args = argparser.parse_args(args=argv, namespace=args)


@ -39,5 +39,15 @@
"write_schedule": false,
"rated_dataset": false,
"rated_dataset_target_dropout_percent": 50,
"zero_frequency_noise_ratio": 0.02
"zero_frequency_noise_ratio": 0.02,
"enable_zero_terminal_snr": false,
"load_settings_every_epoch": false,
"min_snr_gamma": null,
"ema_decay_rate": null,
"ema_decay_target": null,
"ema_decay_interval": null,
"ema_decay_device": null,
"ema_decay_sample_raw_training": false,
"ema_decay_sample_ema_model": false,
"ema_decay_resume_model" : null
}


@ -181,7 +181,7 @@ class SampleGenerator:
self.sample_requests = self._make_random_caption_sample_requests()
@torch.no_grad()
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int):
def generate_samples(self, pipe: StableDiffusionPipeline, global_step: int, extra_info: str = ""):
"""
generates samples at different cfg scales and saves them to disk
"""
@ -269,15 +269,15 @@ class SampleGenerator:
prompt = prompts[prompt_idx]
clean_prompt = clean_filename(prompt)
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
result.save(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.jpg", format="JPEG", quality=95, optimize=True, progressive=False)
with open(f"{self.log_folder}/samples/gs{global_step:05}-{sample_index}-{extra_info}{clean_prompt[:100]}.txt", "w", encoding='utf-8') as f:
f.write(str(batch[prompt_idx]))
tfimage = transforms.ToTensor()(result)
if batch[prompt_idx].wants_random_caption:
self.log_writer.add_image(tag=f"sample_{sample_index}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}{extra_info}", img_tensor=tfimage, global_step=global_step)
else:
self.log_writer.add_image(tag=f"sample_{sample_index}_{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
self.log_writer.add_image(tag=f"sample_{sample_index}_{extra_info}{clean_prompt[:100]}", img_tensor=tfimage, global_step=global_step)
sample_index += 1
del result